{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf68dfaf-37f5-463d-b72e-2120d27f0b0d",
   "metadata": {},
   "source": [
    "Imports for this Notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "id": "63528a79-842d-44ec-9109-29df1621cef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the basic imports\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# DataLoader iterates over a dataset in batches, optionally shuffled\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee773350-56e9-4601-a33e-b97c46a2cafc",
   "metadata": {},
   "source": [
    "Load and split the wine quality dataset with Pandas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "id": "d106a41b-e75c-4ca0-96e1-0e5b117992bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data (only get rid of the Id column)\n",
    "wine_data = pd.read_csv(\"./WineQT.csv\", delimiter=\",\").drop(\"Id\", axis=1)\n",
    "\n",
    "# hold out 30% of the rows as the test set; the sampled 70% is the model set\n",
    "wine_model = wine_data.sample(frac=0.7, random_state=123)\n",
    "# the test set is everything NOT in the model set\n",
    "# (wine_train is only defined further down, so it must not be used here)\n",
    "wine_test = wine_data.drop(wine_model.index)\n",
    "\n",
    "# further split the model set into train (70%) and validation (30%)\n",
    "wine_train = wine_model.sample(frac=0.7, random_state=123)\n",
    "wine_validate = wine_model.drop(wine_train.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "id": "29301f87-1919-48ae-8c9e-1a55f4fb1495",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "source: 1143\n",
      "train: 560 validate: 240 test: 343 == 1143\n"
     ]
    }
   ],
   "source": [
    "# verify the three splits together account for every source row\n",
    "# sum the lengths: adding the frames aligns them on index and would only\n",
    "# count the union of the indices, which hides overlapping splits\n",
    "split_total = len(wine_train) + len(wine_validate) + len(wine_test)\n",
    "print(\"source:\", len(wine_data))\n",
    "print(\"train:\", len(wine_train), \"validate:\", len(wine_validate), \"test:\", len(wine_test), \"==\", split_total)\n",
    "assert split_total == len(wine_data), \"train/validate/test sizes do not add up to the source size\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "id": "3087bcdc-5c50-4b79-a2d8-f442a8d81825",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed acidity</th>\n",
       "      <th>volatile acidity</th>\n",
       "      <th>citric acid</th>\n",
       "      <th>residual sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free sulfur dioxide</th>\n",
       "      <th>total sulfur dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>711</th>\n",
       "      <td>7.5</td>\n",
       "      <td>0.71</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.6</td>\n",
       "      <td>0.092</td>\n",
       "      <td>22.0</td>\n",
       "      <td>31.0</td>\n",
       "      <td>0.99635</td>\n",
       "      <td>3.38</td>\n",
       "      <td>0.58</td>\n",
       "      <td>10.0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>936</th>\n",
       "      <td>9.1</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.42</td>\n",
       "      <td>1.8</td>\n",
       "      <td>0.058</td>\n",
       "      <td>9.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.99392</td>\n",
       "      <td>3.18</td>\n",
       "      <td>0.55</td>\n",
       "      <td>11.4</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>683</th>\n",
       "      <td>10.4</td>\n",
       "      <td>0.26</td>\n",
       "      <td>0.48</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.066</td>\n",
       "      <td>6.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.99724</td>\n",
       "      <td>3.33</td>\n",
       "      <td>0.87</td>\n",
       "      <td>10.9</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>904</th>\n",
       "      <td>8.5</td>\n",
       "      <td>0.40</td>\n",
       "      <td>0.40</td>\n",
       "      <td>6.3</td>\n",
       "      <td>0.050</td>\n",
       "      <td>3.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.99566</td>\n",
       "      <td>3.28</td>\n",
       "      <td>0.56</td>\n",
       "      <td>12.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>198</th>\n",
       "      <td>8.9</td>\n",
       "      <td>0.40</td>\n",
       "      <td>0.32</td>\n",
       "      <td>5.6</td>\n",
       "      <td>0.087</td>\n",
       "      <td>10.0</td>\n",
       "      <td>47.0</td>\n",
       "      <td>0.99910</td>\n",
       "      <td>3.38</td>\n",
       "      <td>0.77</td>\n",
       "      <td>10.5</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1076</th>\n",
       "      <td>7.5</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.57</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.106</td>\n",
       "      <td>5.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.99605</td>\n",
       "      <td>3.36</td>\n",
       "      <td>0.55</td>\n",
       "      <td>11.4</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>709</th>\n",
       "      <td>8.9</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.31</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.088</td>\n",
       "      <td>12.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.99570</td>\n",
       "      <td>3.17</td>\n",
       "      <td>0.55</td>\n",
       "      <td>10.4</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>486</th>\n",
       "      <td>8.1</td>\n",
       "      <td>0.78</td>\n",
       "      <td>0.23</td>\n",
       "      <td>2.6</td>\n",
       "      <td>0.059</td>\n",
       "      <td>5.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.99700</td>\n",
       "      <td>3.37</td>\n",
       "      <td>0.56</td>\n",
       "      <td>11.3</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>7.7</td>\n",
       "      <td>0.41</td>\n",
       "      <td>0.76</td>\n",
       "      <td>1.8</td>\n",
       "      <td>0.611</td>\n",
       "      <td>8.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>0.99680</td>\n",
       "      <td>3.06</td>\n",
       "      <td>1.26</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>206</th>\n",
       "      <td>8.7</td>\n",
       "      <td>0.52</td>\n",
       "      <td>0.09</td>\n",
       "      <td>2.5</td>\n",
       "      <td>0.091</td>\n",
       "      <td>20.0</td>\n",
       "      <td>49.0</td>\n",
       "      <td>0.99760</td>\n",
       "      <td>3.34</td>\n",
       "      <td>0.86</td>\n",
       "      <td>10.6</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>560 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
       "711             7.5              0.71         0.00             1.6      0.092   \n",
       "936             9.1              0.34         0.42             1.8      0.058   \n",
       "683            10.4              0.26         0.48             1.9      0.066   \n",
       "904             8.5              0.40         0.40             6.3      0.050   \n",
       "198             8.9              0.40         0.32             5.6      0.087   \n",
       "...             ...               ...          ...             ...        ...   \n",
       "1076            7.5              0.38         0.57             2.3      0.106   \n",
       "709             8.9              0.32         0.31             2.0      0.088   \n",
       "486             8.1              0.78         0.23             2.6      0.059   \n",
       "182             7.7              0.41         0.76             1.8      0.611   \n",
       "206             8.7              0.52         0.09             2.5      0.091   \n",
       "\n",
       "      free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
       "711                  22.0                  31.0  0.99635  3.38       0.58   \n",
       "936                   9.0                  18.0  0.99392  3.18       0.55   \n",
       "683                   6.0                  10.0  0.99724  3.33       0.87   \n",
       "904                   3.0                  10.0  0.99566  3.28       0.56   \n",
       "198                  10.0                  47.0  0.99910  3.38       0.77   \n",
       "...                   ...                   ...      ...   ...        ...   \n",
       "1076                  5.0                  12.0  0.99605  3.36       0.55   \n",
       "709                  12.0                  19.0  0.99570  3.17       0.55   \n",
       "486                   5.0                  15.0  0.99700  3.37       0.56   \n",
       "182                   8.0                  45.0  0.99680  3.06       1.26   \n",
       "206                  20.0                  49.0  0.99760  3.34       0.86   \n",
       "\n",
       "      alcohol  quality  \n",
       "711      10.0        6  \n",
       "936      11.4        5  \n",
       "683      10.9        6  \n",
       "904      12.0        4  \n",
       "198      10.5        7  \n",
       "...       ...      ...  \n",
       "1076     11.4        6  \n",
       "709      10.4        6  \n",
       "486      11.3        5  \n",
       "182       9.4        5  \n",
       "206      10.6        7  \n",
       "\n",
       "[560 rows x 12 columns]"
      ]
     },
     "execution_count": 245,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the training split\n",
    "wine_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "id": "241f2cdb-a7fb-4165-be11-95ca6510082b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed acidity</th>\n",
       "      <th>volatile acidity</th>\n",
       "      <th>citric acid</th>\n",
       "      <th>residual sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free sulfur dioxide</th>\n",
       "      <th>total sulfur dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>309</th>\n",
       "      <td>7.0</td>\n",
       "      <td>0.620</td>\n",
       "      <td>0.18</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.062</td>\n",
       "      <td>7.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>0.99510</td>\n",
       "      <td>3.08</td>\n",
       "      <td>0.60</td>\n",
       "      <td>9.3</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>528</th>\n",
       "      <td>8.3</td>\n",
       "      <td>0.760</td>\n",
       "      <td>0.29</td>\n",
       "      <td>4.2</td>\n",
       "      <td>0.075</td>\n",
       "      <td>12.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>0.99650</td>\n",
       "      <td>3.45</td>\n",
       "      <td>0.68</td>\n",
       "      <td>11.5</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>150</th>\n",
       "      <td>8.2</td>\n",
       "      <td>0.570</td>\n",
       "      <td>0.26</td>\n",
       "      <td>2.2</td>\n",
       "      <td>0.060</td>\n",
       "      <td>28.0</td>\n",
       "      <td>65.0</td>\n",
       "      <td>0.99590</td>\n",
       "      <td>3.30</td>\n",
       "      <td>0.43</td>\n",
       "      <td>10.1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>278</th>\n",
       "      <td>6.6</td>\n",
       "      <td>0.735</td>\n",
       "      <td>0.02</td>\n",
       "      <td>7.9</td>\n",
       "      <td>0.122</td>\n",
       "      <td>68.0</td>\n",
       "      <td>124.0</td>\n",
       "      <td>0.99940</td>\n",
       "      <td>3.47</td>\n",
       "      <td>0.53</td>\n",
       "      <td>9.9</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>490</th>\n",
       "      <td>8.6</td>\n",
       "      <td>0.490</td>\n",
       "      <td>0.51</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.422</td>\n",
       "      <td>16.0</td>\n",
       "      <td>62.0</td>\n",
       "      <td>0.99790</td>\n",
       "      <td>3.03</td>\n",
       "      <td>1.17</td>\n",
       "      <td>9.0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>896</th>\n",
       "      <td>10.4</td>\n",
       "      <td>0.430</td>\n",
       "      <td>0.50</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.068</td>\n",
       "      <td>13.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.99600</td>\n",
       "      <td>3.10</td>\n",
       "      <td>0.87</td>\n",
       "      <td>11.4</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>922</th>\n",
       "      <td>7.6</td>\n",
       "      <td>1.580</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.1</td>\n",
       "      <td>0.137</td>\n",
       "      <td>5.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>0.99476</td>\n",
       "      <td>3.50</td>\n",
       "      <td>0.40</td>\n",
       "      <td>10.9</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>219</th>\n",
       "      <td>8.4</td>\n",
       "      <td>0.650</td>\n",
       "      <td>0.60</td>\n",
       "      <td>2.1</td>\n",
       "      <td>0.112</td>\n",
       "      <td>12.0</td>\n",
       "      <td>90.0</td>\n",
       "      <td>0.99730</td>\n",
       "      <td>3.20</td>\n",
       "      <td>0.52</td>\n",
       "      <td>9.2</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>970</th>\n",
       "      <td>7.3</td>\n",
       "      <td>0.740</td>\n",
       "      <td>0.08</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.094</td>\n",
       "      <td>10.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>0.99576</td>\n",
       "      <td>3.24</td>\n",
       "      <td>0.50</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>288</th>\n",
       "      <td>8.8</td>\n",
       "      <td>0.520</td>\n",
       "      <td>0.34</td>\n",
       "      <td>2.7</td>\n",
       "      <td>0.087</td>\n",
       "      <td>24.0</td>\n",
       "      <td>122.0</td>\n",
       "      <td>0.99820</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.61</td>\n",
       "      <td>9.5</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>240 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
       "309            7.0             0.620         0.18             1.5      0.062   \n",
       "528            8.3             0.760         0.29             4.2      0.075   \n",
       "150            8.2             0.570         0.26             2.2      0.060   \n",
       "278            6.6             0.735         0.02             7.9      0.122   \n",
       "490            8.6             0.490         0.51             2.0      0.422   \n",
       "..             ...               ...          ...             ...        ...   \n",
       "896           10.4             0.430         0.50             2.3      0.068   \n",
       "922            7.6             1.580         0.00             2.1      0.137   \n",
       "219            8.4             0.650         0.60             2.1      0.112   \n",
       "970            7.3             0.740         0.08             1.7      0.094   \n",
       "288            8.8             0.520         0.34             2.7      0.087   \n",
       "\n",
       "     free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
       "309                  7.0                  50.0  0.99510  3.08       0.60   \n",
       "528                 12.0                  16.0  0.99650  3.45       0.68   \n",
       "150                 28.0                  65.0  0.99590  3.30       0.43   \n",
       "278                 68.0                 124.0  0.99940  3.47       0.53   \n",
       "490                 16.0                  62.0  0.99790  3.03       1.17   \n",
       "..                   ...                   ...      ...   ...        ...   \n",
       "896                 13.0                  19.0  0.99600  3.10       0.87   \n",
       "922                  5.0                   9.0  0.99476  3.50       0.40   \n",
       "219                 12.0                  90.0  0.99730  3.20       0.52   \n",
       "970                 10.0                  45.0  0.99576  3.24       0.50   \n",
       "288                 24.0                 122.0  0.99820  3.26       0.61   \n",
       "\n",
       "     alcohol  quality  \n",
       "309      9.3        5  \n",
       "528     11.5        6  \n",
       "150     10.1        5  \n",
       "278      9.9        5  \n",
       "490      9.0        5  \n",
       "..       ...      ...  \n",
       "896     11.4        6  \n",
       "922     10.9        3  \n",
       "219      9.2        5  \n",
       "970      9.8        5  \n",
       "288      9.5        5  \n",
       "\n",
       "[240 rows x 12 columns]"
      ]
     },
     "execution_count": 246,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the validation split\n",
    "wine_validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "id": "b1e11d58-2415-4a17-9883-410e1bd15c21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed acidity</th>\n",
       "      <th>volatile acidity</th>\n",
       "      <th>citric acid</th>\n",
       "      <th>residual sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free sulfur dioxide</th>\n",
       "      <th>total sulfur dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.6</td>\n",
       "      <td>0.098</td>\n",
       "      <td>25.0</td>\n",
       "      <td>67.0</td>\n",
       "      <td>0.99680</td>\n",
       "      <td>3.20</td>\n",
       "      <td>0.68</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.04</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.092</td>\n",
       "      <td>15.0</td>\n",
       "      <td>54.0</td>\n",
       "      <td>0.99700</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.65</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11.2</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.075</td>\n",
       "      <td>17.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.99800</td>\n",
       "      <td>3.16</td>\n",
       "      <td>0.58</td>\n",
       "      <td>9.8</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7.9</td>\n",
       "      <td>0.60</td>\n",
       "      <td>0.06</td>\n",
       "      <td>1.6</td>\n",
       "      <td>0.069</td>\n",
       "      <td>15.0</td>\n",
       "      <td>59.0</td>\n",
       "      <td>0.99640</td>\n",
       "      <td>3.30</td>\n",
       "      <td>0.46</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.58</td>\n",
       "      <td>0.02</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.073</td>\n",
       "      <td>9.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.99680</td>\n",
       "      <td>3.36</td>\n",
       "      <td>0.57</td>\n",
       "      <td>9.5</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1128</th>\n",
       "      <td>6.2</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.15</td>\n",
       "      <td>5.1</td>\n",
       "      <td>0.076</td>\n",
       "      <td>13.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>0.99622</td>\n",
       "      <td>3.54</td>\n",
       "      <td>0.60</td>\n",
       "      <td>11.9</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1133</th>\n",
       "      <td>6.7</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.44</td>\n",
       "      <td>2.4</td>\n",
       "      <td>0.061</td>\n",
       "      <td>24.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.99484</td>\n",
       "      <td>3.29</td>\n",
       "      <td>0.80</td>\n",
       "      <td>11.6</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1135</th>\n",
       "      <td>5.8</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.11</td>\n",
       "      <td>1.8</td>\n",
       "      <td>0.066</td>\n",
       "      <td>18.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>0.99483</td>\n",
       "      <td>3.55</td>\n",
       "      <td>0.66</td>\n",
       "      <td>10.9</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1137</th>\n",
       "      <td>5.4</td>\n",
       "      <td>0.74</td>\n",
       "      <td>0.09</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.089</td>\n",
       "      <td>16.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0.99402</td>\n",
       "      <td>3.67</td>\n",
       "      <td>0.56</td>\n",
       "      <td>11.6</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1140</th>\n",
       "      <td>6.2</td>\n",
       "      <td>0.60</td>\n",
       "      <td>0.08</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.090</td>\n",
       "      <td>32.0</td>\n",
       "      <td>44.0</td>\n",
       "      <td>0.99490</td>\n",
       "      <td>3.45</td>\n",
       "      <td>0.58</td>\n",
       "      <td>10.5</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>343 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
       "1               7.8              0.88         0.00             2.6      0.098   \n",
       "2               7.8              0.76         0.04             2.3      0.092   \n",
       "3              11.2              0.28         0.56             1.9      0.075   \n",
       "6               7.9              0.60         0.06             1.6      0.069   \n",
       "8               7.8              0.58         0.02             2.0      0.073   \n",
       "...             ...               ...          ...             ...        ...   \n",
       "1128            6.2              0.70         0.15             5.1      0.076   \n",
       "1133            6.7              0.32         0.44             2.4      0.061   \n",
       "1135            5.8              0.61         0.11             1.8      0.066   \n",
       "1137            5.4              0.74         0.09             1.7      0.089   \n",
       "1140            6.2              0.60         0.08             2.0      0.090   \n",
       "\n",
       "      free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
       "1                    25.0                  67.0  0.99680  3.20       0.68   \n",
       "2                    15.0                  54.0  0.99700  3.26       0.65   \n",
       "3                    17.0                  60.0  0.99800  3.16       0.58   \n",
       "6                    15.0                  59.0  0.99640  3.30       0.46   \n",
       "8                     9.0                  18.0  0.99680  3.36       0.57   \n",
       "...                   ...                   ...      ...   ...        ...   \n",
       "1128                 13.0                  27.0  0.99622  3.54       0.60   \n",
       "1133                 24.0                  34.0  0.99484  3.29       0.80   \n",
       "1135                 18.0                  28.0  0.99483  3.55       0.66   \n",
       "1137                 16.0                  26.0  0.99402  3.67       0.56   \n",
       "1140                 32.0                  44.0  0.99490  3.45       0.58   \n",
       "\n",
       "      alcohol  quality  \n",
       "1         9.8        5  \n",
       "2         9.8        5  \n",
       "3         9.8        6  \n",
       "6         9.4        5  \n",
       "8         9.5        7  \n",
       "...       ...      ...  \n",
       "1128     11.9        6  \n",
       "1133     11.6        7  \n",
       "1135     10.9        6  \n",
       "1137     11.6        6  \n",
       "1140     10.5        5  \n",
       "\n",
       "[343 rows x 12 columns]"
      ]
     },
     "execution_count": 247,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the held-out test split\n",
    "wine_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "id": "88bc3d8e-10c7-4887-83c8-21e0274349a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# isolate results from features\n",
    "# now this is the hassle most of the time - convert whatever the input is to whatever it's expected to be\n",
    "# in our case, pandas data frames to pytorch tensors\n",
    "def frame_to_tensors(frame):\n",
    "    \"\"\"Split a wine data frame into (features, target) tensors.\n",
    "\n",
    "    Every column except 'quality' becomes a float32 feature tensor;\n",
    "    the 'quality' column becomes an int64 target tensor.\n",
    "    \"\"\"\n",
    "    features = torch.tensor(frame.drop('quality', axis=1).values.astype(np.float32))\n",
    "    target = torch.tensor(frame['quality'].values.astype(np.int64))\n",
    "    return features, target\n",
    "\n",
    "\n",
    "# one conversion per split instead of three copy-pasted pairs of lines\n",
    "train_features, train_target = frame_to_tensors(wine_train)\n",
    "validate_features, validate_target = frame_to_tensors(wine_validate)\n",
    "test_features, test_target = frame_to_tensors(wine_test)\n",
    "\n",
    "# pair each feature tensor with its target tensor so samples can be indexed together\n",
    "train_data = torch.utils.data.TensorDataset(train_features, train_target)\n",
    "validate_data = torch.utils.data.TensorDataset(validate_features, validate_target)\n",
    "test_data = torch.utils.data.TensorDataset(test_features, test_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "id": "21a01ec6-44d9-4f50-b093-e80816345dd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 7.5000,  0.7100,  0.0000,  1.6000,  0.0920, 22.0000, 31.0000,  0.9963,\n",
       "          3.3800,  0.5800, 10.0000]),\n",
       " tensor(6))"
      ]
     },
     "execution_count": 249,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# first training sample as a (features, target) pair\n",
    "train_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "id": "84198c1f-5a45-45ad-94f8-9e3de5a4efd5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 7.0000,  0.6200,  0.1800,  1.5000,  0.0620,  7.0000, 50.0000,  0.9951,\n",
       "          3.0800,  0.6000,  9.3000]),\n",
       " tensor(5))"
      ]
     },
     "execution_count": 250,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "validate_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "id": "c0685118-8baf-48d1-a554-a8b94a85d39a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 7.8000,  0.8800,  0.0000,  2.6000,  0.0980, 25.0000, 67.0000,  0.9968,\n",
       "          3.2000,  0.6800,  9.8000]),\n",
       " tensor(5))"
      ]
     },
     "execution_count": 251,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "id": "c59d990a-3f85-4981-a288-8877ef1fa34c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create data loaders\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)\n",
    "validate_loader = torch.utils.data.DataLoader(validate_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "id": "e3aaa169-e086-4a19-865d-557c0c2aa210",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accelerated by mps\n"
     ]
    }
   ],
   "source": [
    "# check what device to use\n",
    "from torch.accelerator import is_available, current_accelerator\n",
    "if (is_available()):\n",
    "    device = current_accelerator().type\n",
    "    print(f\"accelerated by {device}\")\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "    print(\"no accelerator\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "id": "a1ee2d09-caed-4d39-ab51-cc27504753ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**** DEFINING Linear/ReLU layer stack ****\n"
     ]
    }
   ],
   "source": [
    "# define the model\n",
    "print(\"**** DEFINING Linear/ReLU layer stack ****\")\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "class SeqNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(11, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 10),\n",
    "            nn.Softmax(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 230,
   "id": "9e95ca59-e228-4b84-b3ec-f12c6e6dcb02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SeqNN(\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=11, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=10, bias=True)\n",
      "    (3): Softmax(dim=None)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# show what the model looks like\n",
    "model = SeqNN() #.to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "6df0edaa-bf1a-4ff5-ae41-2f1c03811c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the loss function and optimizer\n",
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.90)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "id": "102c4fc5-0798-40c5-bb32-9c87fc7707d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(ep_num):\n",
    "    running_loss = 0.\n",
    "    last_loss = 0.\n",
    "\n",
    "    report_every = 5\n",
    "\n",
    "    for idx, batch in enumerate(train_loader):\n",
    "        # unpack the next batch\n",
    "        features, labels = batch\n",
    "    \n",
    "        # ensure features are on the accelerator device (if any)\n",
    "        features.to(device)\n",
    "        \n",
    "        # zero the optimizer gradients\n",
    "        optimizer.zero_grad()\n",
    "    \n",
    "        # feed it to model\n",
    "        outputs = model(features)\n",
    "    \n",
    "        # calculate the loss against the expected label\n",
    "        loss = loss_fn(outputs, labels)\n",
    "        loss.backward()\n",
    "    \n",
    "        # feed the new loss info to optimizer\n",
    "        optimizer.step()\n",
    "    \n",
    "        # calculate whether this is an improvement or a degradation\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        # report loss uppon every 20 items\n",
    "        if idx % report_every == (report_every - 1):\n",
    "            last_loss = running_loss / report_every\n",
    "            print(\"  batch {} loss: {}\".format(idx + 1, last_loss))\n",
    "            tb_x = ep_num * len(train_loader) + idx + 1\n",
    "            print(\"  loss/train: {}/{}\".format(last_loss, tb_x))\n",
    "            running_loss = 0.\n",
    "\n",
    "    return last_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "id": "cdf2c8ae-d024-4687-b1df-23804f3d9816",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**** TRAINING NN on supplied data ****\n",
      "EPOCH 1\n",
      "  batch 5 loss: 2.4407766819000245\n",
      "  loss/train: 2.4407766819000245/5\n",
      "  batch 10 loss: 2.451565170288086\n",
      "  loss/train: 2.451565170288086/10\n",
      "  batch 15 loss: 2.43599009513855\n",
      "  loss/train: 2.43599009513855/15\n",
      "LOSS train 2.43599009513855 / valid 2.4329328536987305\n",
      "EPOCH 2\n",
      "  batch 5 loss: 2.444344425201416\n",
      "  loss/train: 2.444344425201416/23\n",
      "  batch 10 loss: 2.4506516456604004\n",
      "  loss/train: 2.4506516456604004/28\n",
      "  batch 15 loss: 2.4280431270599365\n",
      "  loss/train: 2.4280431270599365/33\n",
      "LOSS train 2.4280431270599365 / valid 2.4371285438537598\n",
      "EPOCH 3\n",
      "  batch 5 loss: 2.444630241394043\n",
      "  loss/train: 2.444630241394043/41\n",
      "  batch 10 loss: 2.4294994354248045\n",
      "  loss/train: 2.4294994354248045/46\n",
      "  batch 15 loss: 2.4486732959747313\n",
      "  loss/train: 2.4486732959747313/51\n",
      "LOSS train 2.4486732959747313 / valid 2.437175989151001\n",
      "EPOCH 4\n",
      "  batch 5 loss: 2.449464464187622\n",
      "  loss/train: 2.449464464187622/59\n",
      "  batch 10 loss: 2.435447835922241\n",
      "  loss/train: 2.435447835922241/64\n",
      "  batch 15 loss: 2.4380492210388183\n",
      "  loss/train: 2.4380492210388183/69\n",
      "LOSS train 2.4380492210388183 / valid 2.437028408050537\n",
      "EPOCH 5\n",
      "  batch 5 loss: 2.4409915447235107\n",
      "  loss/train: 2.4409915447235107/77\n",
      "  batch 10 loss: 2.443865346908569\n",
      "  loss/train: 2.443865346908569/82\n",
      "  batch 15 loss: 2.4377002716064453\n",
      "  loss/train: 2.4377002716064453/87\n",
      "LOSS train 2.4377002716064453 / valid 2.4371585845947266\n"
     ]
    }
   ],
   "source": [
    "# train the model\n",
    "print(\"**** TRAINING NN on supplied data ****\")\n",
    "\n",
    "epoch_no = 0\n",
    "num_epochs = 5\n",
    "best_vloss = 1000000\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    print(\"EPOCH\", epoch + 1)\n",
    "\n",
    "    model.train(True)\n",
    "    avg_loss = train_epoch(epoch_no)\n",
    "\n",
    "    running_vloss = 0.0\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for idx, vdata in enumerate(validate_loader):\n",
    "            vfeatures, vlabels = vdata\n",
    "            voutputs = model(vfeatures)\n",
    "            vloss = loss_fn(voutputs, vlabels)\n",
    "            running_vloss += vloss\n",
    "\n",
    "    avg_vloss = running_vloss / (idx + 1)\n",
    "    print(\"LOSS train {} / valid {}\".format(avg_loss, avg_vloss))\n",
    "\n",
    "    if avg_vloss < best_vloss:\n",
    "        best_vloss = avg_vloss\n",
    "\n",
    "    epoch_no += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "id": "4ef4f484-5664-45dd-a5aa-2909276cedc8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SeqNN(\n",
       "  (linear_relu_stack): Sequential(\n",
       "    (0): Linear(in_features=11, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=512, out_features=10, bias=True)\n",
       "    (3): Softmax(dim=None)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 259,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ff5d2c3-e979-434f-a6ac-355c68832952",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-26",
   "language": "python",
   "name": "pytorch-26"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}