
reorg, add to README

Grega Bremec 2 weeks ago
Parent
Current commit
43991a6863
7 files changed, with 1541 insertions and 5 deletions
  1. README.adoc (+7 -5)
  2. code/wine-pytorch.ipynb (+1262 -0)
  3. code/wine-sklearn.ipynb (+201 -0)
  4. code/wine-sklearn.py (+71 -0)
  5. envs/env-pytorch-26.yml (+0 -0)
  6. envs/env-sklearn-16.yml (+0 -0)
  7. envs/env-tf-216.yml (+0 -0)

+ 7 - 5
README.adoc

@@ -18,7 +18,7 @@ Just looking at Python's most popular ML frameworks:
 
 Additionally, https://keras.io/[Keras] is a multi-framework ML frontend that can work with TensorFlow, PyTorch, SciKit-Learn, JAX, and others.
 
-The problem is this is the middle layer nowadays, not even the topmost any more.
+The challenge is that this is only the middle layer nowadays, not even the topmost any more.
 
 These frameworks all come with their own tools to make the jobs of working with data and training models easier:
 
@@ -64,7 +64,9 @@ MLOps is another can of worms:
 * https://www.kubeflow.org/[KubeFlow] for Kubernetes-related integrations (included in RHOAI)
 * https://github.com/elyra-ai/elyra[Elyra], integrating with JupyterLab Notebooks
 
-Don't worry - start small and when the sandbox becomes too small, look around to see what can extend it.
+Don't worry - start small. When you outgrow the sandbox, look at your project, pick whichever direction sounds most interesting at the moment, and then look around for tools you can learn to expand in that direction.
+
+We all needed years to get our sense of direction in this story. There are no shortcuts, just fun along the way.
 
 === Get Some Samples ===
 
@@ -76,8 +78,8 @@ The data you need for those is typically very different. Try figuring out why an
 
 Nowadays, https://www.kaggle.com/datasets[Kaggle] has some great example datasets (requires a free account).
 
-* https://casas.wsu.edu/datasets/[CASAS-HAR] (Human Activity Recognition from Continuous Ambient Sensor Data) is a very flexible regression dataset (available for https://www.kaggle.com/datasets/utkarshx27/ambient-sensor-based-human-activity-recognition[download from Kaggle]).
-* MNIST has two excellent classification datasets: handwritten numbers and fashion items - they are a great start for classification - almost every framework includes them. Both are also available at Kaggle - https://www.kaggle.com/datasets/hojjatk/mnist-dataset[numbers] and https://www.kaggle.com/datasets/zalando-research/fashionmnist[fashion].
+* https://www.nist.gov/el/ammt-temps/datasets[NIST] has two excellent image classification datasets: handwritten digits and fashion items. They are a great start for classification, and almost every framework includes them. Both are also available at Kaggle - https://www.kaggle.com/datasets/hojjatk/mnist-dataset[digits] and https://www.kaggle.com/datasets/zalando-research/fashionmnist[fashion].
+* https://casas.wsu.edu/datasets/[CASAS] has a _Human Activity Recognition from Continuous Ambient Sensor Data_ dataset - it is a very flexible regression dataset offering tons of opportunities (available for https://www.kaggle.com/datasets/utkarshx27/ambient-sensor-based-human-activity-recognition[download from Kaggle]).
 
 == Getting Running ==
 
@@ -335,7 +337,7 @@ JupyterLab Notebooks were designed to resolve those problems by being something
 
 Not only that - you can define different Python kernels which belong to various Conda environments, in the same JupyterLab instance, and simply associate your notebooks with the kernel they need, so that they can run in whichever environment you want them to.
 
-If you want to use themm, the best way to do it is to install `jupyterlab` into the base environment.
+If you want to use them, the best way to do it is to install `jupyterlab` into the base environment.
 
 [subs="+quotes"]
 ----

+ 1262 - 0
code/wine-pytorch.ipynb

@@ -0,0 +1,1262 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "bf68dfaf-37f5-463d-b72e-2120d27f0b0d",
+   "metadata": {},
+   "source": [
+    "Imports for this Notebook:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 234,
+   "id": "63528a79-842d-44ec-9109-29df1621cef8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the basic imports\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "\n",
+    "# DataLoader wraps a dataset and serves it in batches\n",
+    "from torch.utils.data import DataLoader"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ee773350-56e9-4601-a33e-b97c46a2cafc",
+   "metadata": {},
+   "source": [
+    "Load and split the wine quality dataset with Pandas."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 243,
+   "id": "d106a41b-e75c-4ca0-96e1-0e5b117992bf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load data (only get rid of the Id column)\n",
+    "wine_data = pd.read_csv(\"./WineQT.csv\", delimiter=\",\").drop(\"Id\", axis=1)\n",
+    "\n",
+    "# split the dataset into model and test subsets\n",
+    "wine_model = wine_data.sample(frac=0.7, random_state=123)\n",
+    "wine_test = wine_data.drop(wine_model.index)\n",
+    "\n",
+    "# further split the model subset into train and validation\n",
+    "wine_train = wine_model.sample(frac=0.7, random_state=123)\n",
+    "wine_validate = wine_model.drop(wine_train.index)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 244,
+   "id": "29301f87-1919-48ae-8c9e-1a55f4fb1495",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "source: 1143\n",
+      "train: 560 validate: 240 test: 343 == 1143\n"
+     ]
+    }
+   ],
+   "source": [
+    "# verify the produced data size matches the source\n",
+    "print(\"source:\", len(wine_data))\n",
+    "print(\"train:\", len(wine_train), \"validate:\", len(wine_validate), \"test:\", len(wine_test), \"==\", len(wine_train) + len(wine_validate) + len(wine_test))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 245,
+   "id": "3087bcdc-5c50-4b79-a2d8-f442a8d81825",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>fixed acidity</th>\n",
+       "      <th>volatile acidity</th>\n",
+       "      <th>citric acid</th>\n",
+       "      <th>residual sugar</th>\n",
+       "      <th>chlorides</th>\n",
+       "      <th>free sulfur dioxide</th>\n",
+       "      <th>total sulfur dioxide</th>\n",
+       "      <th>density</th>\n",
+       "      <th>pH</th>\n",
+       "      <th>sulphates</th>\n",
+       "      <th>alcohol</th>\n",
+       "      <th>quality</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>711</th>\n",
+       "      <td>7.5</td>\n",
+       "      <td>0.71</td>\n",
+       "      <td>0.00</td>\n",
+       "      <td>1.6</td>\n",
+       "      <td>0.092</td>\n",
+       "      <td>22.0</td>\n",
+       "      <td>31.0</td>\n",
+       "      <td>0.99635</td>\n",
+       "      <td>3.38</td>\n",
+       "      <td>0.58</td>\n",
+       "      <td>10.0</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>936</th>\n",
+       "      <td>9.1</td>\n",
+       "      <td>0.34</td>\n",
+       "      <td>0.42</td>\n",
+       "      <td>1.8</td>\n",
+       "      <td>0.058</td>\n",
+       "      <td>9.0</td>\n",
+       "      <td>18.0</td>\n",
+       "      <td>0.99392</td>\n",
+       "      <td>3.18</td>\n",
+       "      <td>0.55</td>\n",
+       "      <td>11.4</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>683</th>\n",
+       "      <td>10.4</td>\n",
+       "      <td>0.26</td>\n",
+       "      <td>0.48</td>\n",
+       "      <td>1.9</td>\n",
+       "      <td>0.066</td>\n",
+       "      <td>6.0</td>\n",
+       "      <td>10.0</td>\n",
+       "      <td>0.99724</td>\n",
+       "      <td>3.33</td>\n",
+       "      <td>0.87</td>\n",
+       "      <td>10.9</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>904</th>\n",
+       "      <td>8.5</td>\n",
+       "      <td>0.40</td>\n",
+       "      <td>0.40</td>\n",
+       "      <td>6.3</td>\n",
+       "      <td>0.050</td>\n",
+       "      <td>3.0</td>\n",
+       "      <td>10.0</td>\n",
+       "      <td>0.99566</td>\n",
+       "      <td>3.28</td>\n",
+       "      <td>0.56</td>\n",
+       "      <td>12.0</td>\n",
+       "      <td>4</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>198</th>\n",
+       "      <td>8.9</td>\n",
+       "      <td>0.40</td>\n",
+       "      <td>0.32</td>\n",
+       "      <td>5.6</td>\n",
+       "      <td>0.087</td>\n",
+       "      <td>10.0</td>\n",
+       "      <td>47.0</td>\n",
+       "      <td>0.99910</td>\n",
+       "      <td>3.38</td>\n",
+       "      <td>0.77</td>\n",
+       "      <td>10.5</td>\n",
+       "      <td>7</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1076</th>\n",
+       "      <td>7.5</td>\n",
+       "      <td>0.38</td>\n",
+       "      <td>0.57</td>\n",
+       "      <td>2.3</td>\n",
+       "      <td>0.106</td>\n",
+       "      <td>5.0</td>\n",
+       "      <td>12.0</td>\n",
+       "      <td>0.99605</td>\n",
+       "      <td>3.36</td>\n",
+       "      <td>0.55</td>\n",
+       "      <td>11.4</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>709</th>\n",
+       "      <td>8.9</td>\n",
+       "      <td>0.32</td>\n",
+       "      <td>0.31</td>\n",
+       "      <td>2.0</td>\n",
+       "      <td>0.088</td>\n",
+       "      <td>12.0</td>\n",
+       "      <td>19.0</td>\n",
+       "      <td>0.99570</td>\n",
+       "      <td>3.17</td>\n",
+       "      <td>0.55</td>\n",
+       "      <td>10.4</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>486</th>\n",
+       "      <td>8.1</td>\n",
+       "      <td>0.78</td>\n",
+       "      <td>0.23</td>\n",
+       "      <td>2.6</td>\n",
+       "      <td>0.059</td>\n",
+       "      <td>5.0</td>\n",
+       "      <td>15.0</td>\n",
+       "      <td>0.99700</td>\n",
+       "      <td>3.37</td>\n",
+       "      <td>0.56</td>\n",
+       "      <td>11.3</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>182</th>\n",
+       "      <td>7.7</td>\n",
+       "      <td>0.41</td>\n",
+       "      <td>0.76</td>\n",
+       "      <td>1.8</td>\n",
+       "      <td>0.611</td>\n",
+       "      <td>8.0</td>\n",
+       "      <td>45.0</td>\n",
+       "      <td>0.99680</td>\n",
+       "      <td>3.06</td>\n",
+       "      <td>1.26</td>\n",
+       "      <td>9.4</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>206</th>\n",
+       "      <td>8.7</td>\n",
+       "      <td>0.52</td>\n",
+       "      <td>0.09</td>\n",
+       "      <td>2.5</td>\n",
+       "      <td>0.091</td>\n",
+       "      <td>20.0</td>\n",
+       "      <td>49.0</td>\n",
+       "      <td>0.99760</td>\n",
+       "      <td>3.34</td>\n",
+       "      <td>0.86</td>\n",
+       "      <td>10.6</td>\n",
+       "      <td>7</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>560 rows × 12 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
+       "711             7.5              0.71         0.00             1.6      0.092   \n",
+       "936             9.1              0.34         0.42             1.8      0.058   \n",
+       "683            10.4              0.26         0.48             1.9      0.066   \n",
+       "904             8.5              0.40         0.40             6.3      0.050   \n",
+       "198             8.9              0.40         0.32             5.6      0.087   \n",
+       "...             ...               ...          ...             ...        ...   \n",
+       "1076            7.5              0.38         0.57             2.3      0.106   \n",
+       "709             8.9              0.32         0.31             2.0      0.088   \n",
+       "486             8.1              0.78         0.23             2.6      0.059   \n",
+       "182             7.7              0.41         0.76             1.8      0.611   \n",
+       "206             8.7              0.52         0.09             2.5      0.091   \n",
+       "\n",
+       "      free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
+       "711                  22.0                  31.0  0.99635  3.38       0.58   \n",
+       "936                   9.0                  18.0  0.99392  3.18       0.55   \n",
+       "683                   6.0                  10.0  0.99724  3.33       0.87   \n",
+       "904                   3.0                  10.0  0.99566  3.28       0.56   \n",
+       "198                  10.0                  47.0  0.99910  3.38       0.77   \n",
+       "...                   ...                   ...      ...   ...        ...   \n",
+       "1076                  5.0                  12.0  0.99605  3.36       0.55   \n",
+       "709                  12.0                  19.0  0.99570  3.17       0.55   \n",
+       "486                   5.0                  15.0  0.99700  3.37       0.56   \n",
+       "182                   8.0                  45.0  0.99680  3.06       1.26   \n",
+       "206                  20.0                  49.0  0.99760  3.34       0.86   \n",
+       "\n",
+       "      alcohol  quality  \n",
+       "711      10.0        6  \n",
+       "936      11.4        5  \n",
+       "683      10.9        6  \n",
+       "904      12.0        4  \n",
+       "198      10.5        7  \n",
+       "...       ...      ...  \n",
+       "1076     11.4        6  \n",
+       "709      10.4        6  \n",
+       "486      11.3        5  \n",
+       "182       9.4        5  \n",
+       "206      10.6        7  \n",
+       "\n",
+       "[560 rows x 12 columns]"
+      ]
+     },
+     "execution_count": 245,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "wine_train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 246,
+   "id": "241f2cdb-a7fb-4165-be11-95ca6510082b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>fixed acidity</th>\n",
+       "      <th>volatile acidity</th>\n",
+       "      <th>citric acid</th>\n",
+       "      <th>residual sugar</th>\n",
+       "      <th>chlorides</th>\n",
+       "      <th>free sulfur dioxide</th>\n",
+       "      <th>total sulfur dioxide</th>\n",
+       "      <th>density</th>\n",
+       "      <th>pH</th>\n",
+       "      <th>sulphates</th>\n",
+       "      <th>alcohol</th>\n",
+       "      <th>quality</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>309</th>\n",
+       "      <td>7.0</td>\n",
+       "      <td>0.620</td>\n",
+       "      <td>0.18</td>\n",
+       "      <td>1.5</td>\n",
+       "      <td>0.062</td>\n",
+       "      <td>7.0</td>\n",
+       "      <td>50.0</td>\n",
+       "      <td>0.99510</td>\n",
+       "      <td>3.08</td>\n",
+       "      <td>0.60</td>\n",
+       "      <td>9.3</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>528</th>\n",
+       "      <td>8.3</td>\n",
+       "      <td>0.760</td>\n",
+       "      <td>0.29</td>\n",
+       "      <td>4.2</td>\n",
+       "      <td>0.075</td>\n",
+       "      <td>12.0</td>\n",
+       "      <td>16.0</td>\n",
+       "      <td>0.99650</td>\n",
+       "      <td>3.45</td>\n",
+       "      <td>0.68</td>\n",
+       "      <td>11.5</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>150</th>\n",
+       "      <td>8.2</td>\n",
+       "      <td>0.570</td>\n",
+       "      <td>0.26</td>\n",
+       "      <td>2.2</td>\n",
+       "      <td>0.060</td>\n",
+       "      <td>28.0</td>\n",
+       "      <td>65.0</td>\n",
+       "      <td>0.99590</td>\n",
+       "      <td>3.30</td>\n",
+       "      <td>0.43</td>\n",
+       "      <td>10.1</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>278</th>\n",
+       "      <td>6.6</td>\n",
+       "      <td>0.735</td>\n",
+       "      <td>0.02</td>\n",
+       "      <td>7.9</td>\n",
+       "      <td>0.122</td>\n",
+       "      <td>68.0</td>\n",
+       "      <td>124.0</td>\n",
+       "      <td>0.99940</td>\n",
+       "      <td>3.47</td>\n",
+       "      <td>0.53</td>\n",
+       "      <td>9.9</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>490</th>\n",
+       "      <td>8.6</td>\n",
+       "      <td>0.490</td>\n",
+       "      <td>0.51</td>\n",
+       "      <td>2.0</td>\n",
+       "      <td>0.422</td>\n",
+       "      <td>16.0</td>\n",
+       "      <td>62.0</td>\n",
+       "      <td>0.99790</td>\n",
+       "      <td>3.03</td>\n",
+       "      <td>1.17</td>\n",
+       "      <td>9.0</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>896</th>\n",
+       "      <td>10.4</td>\n",
+       "      <td>0.430</td>\n",
+       "      <td>0.50</td>\n",
+       "      <td>2.3</td>\n",
+       "      <td>0.068</td>\n",
+       "      <td>13.0</td>\n",
+       "      <td>19.0</td>\n",
+       "      <td>0.99600</td>\n",
+       "      <td>3.10</td>\n",
+       "      <td>0.87</td>\n",
+       "      <td>11.4</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>922</th>\n",
+       "      <td>7.6</td>\n",
+       "      <td>1.580</td>\n",
+       "      <td>0.00</td>\n",
+       "      <td>2.1</td>\n",
+       "      <td>0.137</td>\n",
+       "      <td>5.0</td>\n",
+       "      <td>9.0</td>\n",
+       "      <td>0.99476</td>\n",
+       "      <td>3.50</td>\n",
+       "      <td>0.40</td>\n",
+       "      <td>10.9</td>\n",
+       "      <td>3</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>219</th>\n",
+       "      <td>8.4</td>\n",
+       "      <td>0.650</td>\n",
+       "      <td>0.60</td>\n",
+       "      <td>2.1</td>\n",
+       "      <td>0.112</td>\n",
+       "      <td>12.0</td>\n",
+       "      <td>90.0</td>\n",
+       "      <td>0.99730</td>\n",
+       "      <td>3.20</td>\n",
+       "      <td>0.52</td>\n",
+       "      <td>9.2</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>970</th>\n",
+       "      <td>7.3</td>\n",
+       "      <td>0.740</td>\n",
+       "      <td>0.08</td>\n",
+       "      <td>1.7</td>\n",
+       "      <td>0.094</td>\n",
+       "      <td>10.0</td>\n",
+       "      <td>45.0</td>\n",
+       "      <td>0.99576</td>\n",
+       "      <td>3.24</td>\n",
+       "      <td>0.50</td>\n",
+       "      <td>9.8</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>288</th>\n",
+       "      <td>8.8</td>\n",
+       "      <td>0.520</td>\n",
+       "      <td>0.34</td>\n",
+       "      <td>2.7</td>\n",
+       "      <td>0.087</td>\n",
+       "      <td>24.0</td>\n",
+       "      <td>122.0</td>\n",
+       "      <td>0.99820</td>\n",
+       "      <td>3.26</td>\n",
+       "      <td>0.61</td>\n",
+       "      <td>9.5</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>240 rows × 12 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "     fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
+       "309            7.0             0.620         0.18             1.5      0.062   \n",
+       "528            8.3             0.760         0.29             4.2      0.075   \n",
+       "150            8.2             0.570         0.26             2.2      0.060   \n",
+       "278            6.6             0.735         0.02             7.9      0.122   \n",
+       "490            8.6             0.490         0.51             2.0      0.422   \n",
+       "..             ...               ...          ...             ...        ...   \n",
+       "896           10.4             0.430         0.50             2.3      0.068   \n",
+       "922            7.6             1.580         0.00             2.1      0.137   \n",
+       "219            8.4             0.650         0.60             2.1      0.112   \n",
+       "970            7.3             0.740         0.08             1.7      0.094   \n",
+       "288            8.8             0.520         0.34             2.7      0.087   \n",
+       "\n",
+       "     free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
+       "309                  7.0                  50.0  0.99510  3.08       0.60   \n",
+       "528                 12.0                  16.0  0.99650  3.45       0.68   \n",
+       "150                 28.0                  65.0  0.99590  3.30       0.43   \n",
+       "278                 68.0                 124.0  0.99940  3.47       0.53   \n",
+       "490                 16.0                  62.0  0.99790  3.03       1.17   \n",
+       "..                   ...                   ...      ...   ...        ...   \n",
+       "896                 13.0                  19.0  0.99600  3.10       0.87   \n",
+       "922                  5.0                   9.0  0.99476  3.50       0.40   \n",
+       "219                 12.0                  90.0  0.99730  3.20       0.52   \n",
+       "970                 10.0                  45.0  0.99576  3.24       0.50   \n",
+       "288                 24.0                 122.0  0.99820  3.26       0.61   \n",
+       "\n",
+       "     alcohol  quality  \n",
+       "309      9.3        5  \n",
+       "528     11.5        6  \n",
+       "150     10.1        5  \n",
+       "278      9.9        5  \n",
+       "490      9.0        5  \n",
+       "..       ...      ...  \n",
+       "896     11.4        6  \n",
+       "922     10.9        3  \n",
+       "219      9.2        5  \n",
+       "970      9.8        5  \n",
+       "288      9.5        5  \n",
+       "\n",
+       "[240 rows x 12 columns]"
+      ]
+     },
+     "execution_count": 246,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "wine_validate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 247,
+   "id": "b1e11d58-2415-4a17-9883-410e1bd15c21",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>fixed acidity</th>\n",
+       "      <th>volatile acidity</th>\n",
+       "      <th>citric acid</th>\n",
+       "      <th>residual sugar</th>\n",
+       "      <th>chlorides</th>\n",
+       "      <th>free sulfur dioxide</th>\n",
+       "      <th>total sulfur dioxide</th>\n",
+       "      <th>density</th>\n",
+       "      <th>pH</th>\n",
+       "      <th>sulphates</th>\n",
+       "      <th>alcohol</th>\n",
+       "      <th>quality</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>7.8</td>\n",
+       "      <td>0.88</td>\n",
+       "      <td>0.00</td>\n",
+       "      <td>2.6</td>\n",
+       "      <td>0.098</td>\n",
+       "      <td>25.0</td>\n",
+       "      <td>67.0</td>\n",
+       "      <td>0.99680</td>\n",
+       "      <td>3.20</td>\n",
+       "      <td>0.68</td>\n",
+       "      <td>9.8</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>7.8</td>\n",
+       "      <td>0.76</td>\n",
+       "      <td>0.04</td>\n",
+       "      <td>2.3</td>\n",
+       "      <td>0.092</td>\n",
+       "      <td>15.0</td>\n",
+       "      <td>54.0</td>\n",
+       "      <td>0.99700</td>\n",
+       "      <td>3.26</td>\n",
+       "      <td>0.65</td>\n",
+       "      <td>9.8</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>11.2</td>\n",
+       "      <td>0.28</td>\n",
+       "      <td>0.56</td>\n",
+       "      <td>1.9</td>\n",
+       "      <td>0.075</td>\n",
+       "      <td>17.0</td>\n",
+       "      <td>60.0</td>\n",
+       "      <td>0.99800</td>\n",
+       "      <td>3.16</td>\n",
+       "      <td>0.58</td>\n",
+       "      <td>9.8</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>7.9</td>\n",
+       "      <td>0.60</td>\n",
+       "      <td>0.06</td>\n",
+       "      <td>1.6</td>\n",
+       "      <td>0.069</td>\n",
+       "      <td>15.0</td>\n",
+       "      <td>59.0</td>\n",
+       "      <td>0.99640</td>\n",
+       "      <td>3.30</td>\n",
+       "      <td>0.46</td>\n",
+       "      <td>9.4</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>7.8</td>\n",
+       "      <td>0.58</td>\n",
+       "      <td>0.02</td>\n",
+       "      <td>2.0</td>\n",
+       "      <td>0.073</td>\n",
+       "      <td>9.0</td>\n",
+       "      <td>18.0</td>\n",
+       "      <td>0.99680</td>\n",
+       "      <td>3.36</td>\n",
+       "      <td>0.57</td>\n",
+       "      <td>9.5</td>\n",
+       "      <td>7</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1128</th>\n",
+       "      <td>6.2</td>\n",
+       "      <td>0.70</td>\n",
+       "      <td>0.15</td>\n",
+       "      <td>5.1</td>\n",
+       "      <td>0.076</td>\n",
+       "      <td>13.0</td>\n",
+       "      <td>27.0</td>\n",
+       "      <td>0.99622</td>\n",
+       "      <td>3.54</td>\n",
+       "      <td>0.60</td>\n",
+       "      <td>11.9</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1133</th>\n",
+       "      <td>6.7</td>\n",
+       "      <td>0.32</td>\n",
+       "      <td>0.44</td>\n",
+       "      <td>2.4</td>\n",
+       "      <td>0.061</td>\n",
+       "      <td>24.0</td>\n",
+       "      <td>34.0</td>\n",
+       "      <td>0.99484</td>\n",
+       "      <td>3.29</td>\n",
+       "      <td>0.80</td>\n",
+       "      <td>11.6</td>\n",
+       "      <td>7</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1135</th>\n",
+       "      <td>5.8</td>\n",
+       "      <td>0.61</td>\n",
+       "      <td>0.11</td>\n",
+       "      <td>1.8</td>\n",
+       "      <td>0.066</td>\n",
+       "      <td>18.0</td>\n",
+       "      <td>28.0</td>\n",
+       "      <td>0.99483</td>\n",
+       "      <td>3.55</td>\n",
+       "      <td>0.66</td>\n",
+       "      <td>10.9</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1137</th>\n",
+       "      <td>5.4</td>\n",
+       "      <td>0.74</td>\n",
+       "      <td>0.09</td>\n",
+       "      <td>1.7</td>\n",
+       "      <td>0.089</td>\n",
+       "      <td>16.0</td>\n",
+       "      <td>26.0</td>\n",
+       "      <td>0.99402</td>\n",
+       "      <td>3.67</td>\n",
+       "      <td>0.56</td>\n",
+       "      <td>11.6</td>\n",
+       "      <td>6</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1140</th>\n",
+       "      <td>6.2</td>\n",
+       "      <td>0.60</td>\n",
+       "      <td>0.08</td>\n",
+       "      <td>2.0</td>\n",
+       "      <td>0.090</td>\n",
+       "      <td>32.0</td>\n",
+       "      <td>44.0</td>\n",
+       "      <td>0.99490</td>\n",
+       "      <td>3.45</td>\n",
+       "      <td>0.58</td>\n",
+       "      <td>10.5</td>\n",
+       "      <td>5</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>343 rows × 12 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
+       "1               7.8              0.88         0.00             2.6      0.098   \n",
+       "2               7.8              0.76         0.04             2.3      0.092   \n",
+       "3              11.2              0.28         0.56             1.9      0.075   \n",
+       "6               7.9              0.60         0.06             1.6      0.069   \n",
+       "8               7.8              0.58         0.02             2.0      0.073   \n",
+       "...             ...               ...          ...             ...        ...   \n",
+       "1128            6.2              0.70         0.15             5.1      0.076   \n",
+       "1133            6.7              0.32         0.44             2.4      0.061   \n",
+       "1135            5.8              0.61         0.11             1.8      0.066   \n",
+       "1137            5.4              0.74         0.09             1.7      0.089   \n",
+       "1140            6.2              0.60         0.08             2.0      0.090   \n",
+       "\n",
+       "      free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
+       "1                    25.0                  67.0  0.99680  3.20       0.68   \n",
+       "2                    15.0                  54.0  0.99700  3.26       0.65   \n",
+       "3                    17.0                  60.0  0.99800  3.16       0.58   \n",
+       "6                    15.0                  59.0  0.99640  3.30       0.46   \n",
+       "8                     9.0                  18.0  0.99680  3.36       0.57   \n",
+       "...                   ...                   ...      ...   ...        ...   \n",
+       "1128                 13.0                  27.0  0.99622  3.54       0.60   \n",
+       "1133                 24.0                  34.0  0.99484  3.29       0.80   \n",
+       "1135                 18.0                  28.0  0.99483  3.55       0.66   \n",
+       "1137                 16.0                  26.0  0.99402  3.67       0.56   \n",
+       "1140                 32.0                  44.0  0.99490  3.45       0.58   \n",
+       "\n",
+       "      alcohol  quality  \n",
+       "1         9.8        5  \n",
+       "2         9.8        5  \n",
+       "3         9.8        6  \n",
+       "6         9.4        5  \n",
+       "8         9.5        7  \n",
+       "...       ...      ...  \n",
+       "1128     11.9        6  \n",
+       "1133     11.6        7  \n",
+       "1135     10.9        6  \n",
+       "1137     11.6        6  \n",
+       "1140     10.5        5  \n",
+       "\n",
+       "[343 rows x 12 columns]"
+      ]
+     },
+     "execution_count": 247,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "wine_test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 248,
+   "id": "88bc3d8e-10c7-4887-83c8-21e0274349a2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# separate the labels (quality) from the features\n",
+    "# this is usually the fiddly part: converting whatever the input is into whatever the framework expects,\n",
+    "# in our case pandas data frames into PyTorch tensors\n",
+    "train_features = torch.tensor(wine_train.drop('quality', axis=1).values.astype(np.float32))\n",
+    "train_target = torch.tensor(wine_train['quality'].values.astype(np.int64))\n",
+    "validate_features = torch.tensor(wine_validate.drop('quality', axis=1).values.astype(np.float32))\n",
+    "validate_target = torch.tensor(wine_validate['quality'].values.astype(np.int64))\n",
+    "test_features = torch.tensor(wine_test.drop('quality', axis=1).values.astype(np.float32))\n",
+    "test_target = torch.tensor(wine_test['quality'].values.astype(np.int64))\n",
+    "\n",
+    "train_data = torch.utils.data.TensorDataset(train_features, train_target)\n",
+    "validate_data = torch.utils.data.TensorDataset(validate_features, validate_target)\n",
+    "test_data = torch.utils.data.TensorDataset(test_features, test_target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 249,
+   "id": "21a01ec6-44d9-4f50-b093-e80816345dd0",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([ 7.5000,  0.7100,  0.0000,  1.6000,  0.0920, 22.0000, 31.0000,  0.9963,\n",
+       "          3.3800,  0.5800, 10.0000]),\n",
+       " tensor(6))"
+      ]
+     },
+     "execution_count": 249,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_data[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 250,
+   "id": "84198c1f-5a45-45ad-94f8-9e3de5a4efd5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([ 7.0000,  0.6200,  0.1800,  1.5000,  0.0620,  7.0000, 50.0000,  0.9951,\n",
+       "          3.0800,  0.6000,  9.3000]),\n",
+       " tensor(5))"
+      ]
+     },
+     "execution_count": 250,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "validate_data[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 251,
+   "id": "c0685118-8baf-48d1-a554-a8b94a85d39a",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([ 7.8000,  0.8800,  0.0000,  2.6000,  0.0980, 25.0000, 67.0000,  0.9968,\n",
+       "          3.2000,  0.6800,  9.8000]),\n",
+       " tensor(5))"
+      ]
+     },
+     "execution_count": 251,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_data[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 252,
+   "id": "c59d990a-3f85-4981-a288-8877ef1fa34c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create data loaders\n",
+    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)\n",
+    "validate_loader = torch.utils.data.DataLoader(validate_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)\n",
+    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 224,
+   "id": "e3aaa169-e086-4a19-865d-557c0c2aa210",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "accelerated by mps\n"
+     ]
+    }
+   ],
+   "source": [
+    "# check what device to use\n",
+    "from torch.accelerator import is_available, current_accelerator\n",
+    "if is_available():\n",
+    "    device = current_accelerator().type\n",
+    "    print(f\"accelerated by {device}\")\n",
+    "else:\n",
+    "    device = \"cpu\"\n",
+    "    print(\"no accelerator\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 229,
+   "id": "a1ee2d09-caed-4d39-ab51-cc27504753ef",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "**** DEFINING Linear/ReLU layer stack ****\n"
+     ]
+    }
+   ],
+   "source": [
+    "# define the model\n",
+    "print(\"**** DEFINING Linear/ReLU layer stack ****\")\n",
+    "\n",
+    "from torch import nn\n",
+    "\n",
+    "class SeqNN(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.linear_relu_stack = nn.Sequential(\n",
+    "            nn.Linear(11, 512),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(512, 10),\n",
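+    "            # NOTE: CrossEntropyLoss expects raw logits and applies log-softmax itself;\n",
+    "            # with this Softmax in front of it the loss can only move within a narrow band\n",
+    "            nn.Softmax(),\n",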
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        logits = self.linear_relu_stack(x)\n",
+    "        return logits\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 230,
+   "id": "9e95ca59-e228-4b84-b3ec-f12c6e6dcb02",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "SeqNN(\n",
+      "  (linear_relu_stack): Sequential(\n",
+      "    (0): Linear(in_features=11, out_features=512, bias=True)\n",
+      "    (1): ReLU()\n",
+      "    (2): Linear(in_features=512, out_features=10, bias=True)\n",
+      "    (3): Softmax(dim=None)\n",
+      "  )\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "# show what the model looks like\n",
+    "model = SeqNN() #.to(device)\n",
+    "print(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 227,
+   "id": "6df0edaa-bf1a-4ff5-ae41-2f1c03811c7c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the loss function and optimizer\n",
+    "loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.90)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 257,
+   "id": "102c4fc5-0798-40c5-bb32-9c87fc7707d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_epoch(ep_num):\n",
+    "    running_loss = 0.\n",
+    "    last_loss = 0.\n",
+    "\n",
+    "    report_every = 5\n",
+    "\n",
+    "    for idx, batch in enumerate(train_loader):\n",
+    "        # unpack the next batch\n",
+    "        features, labels = batch\n",
+    "    \n",
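+    "        # NOTE: Tensor.to() returns a new tensor instead of moving this one in place, so the call\n",
+    "        # below has no effect as written; the model is also left on the CPU, so the batch stays there too\n",
+    "        features.to(device)\n",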
+    "        \n",
+    "        # zero the optimizer gradients\n",
+    "        optimizer.zero_grad()\n",
+    "    \n",
+    "        # feed it to model\n",
+    "        outputs = model(features)\n",
+    "    \n",
+    "        # calculate the loss against the expected label\n",
+    "        loss = loss_fn(outputs, labels)\n",
+    "        loss.backward()\n",
+    "    \n",
+    "        # let the optimizer update the weights from the computed gradients\n",
+    "        optimizer.step()\n",
+    "    \n",
+    "        # accumulate the batch loss for the running average\n",
+    "        running_loss += loss.item()\n",
+    "\n",
+    "        # report the average loss every report_every batches\n",
+    "        if idx % report_every == (report_every - 1):\n",
+    "            last_loss = running_loss / report_every\n",
+    "            print(\"  batch {} loss: {}\".format(idx + 1, last_loss))\n",
+    "            tb_x = ep_num * len(train_loader) + idx + 1\n",
+    "            print(\"  loss/train: {}/{}\".format(last_loss, tb_x))\n",
+    "            running_loss = 0.\n",
+    "\n",
+    "    return last_loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 258,
+   "id": "cdf2c8ae-d024-4687-b1df-23804f3d9816",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "**** TRAINING NN on supplied data ****\n",
+      "EPOCH 1\n",
+      "  batch 5 loss: 2.4407766819000245\n",
+      "  loss/train: 2.4407766819000245/5\n",
+      "  batch 10 loss: 2.451565170288086\n",
+      "  loss/train: 2.451565170288086/10\n",
+      "  batch 15 loss: 2.43599009513855\n",
+      "  loss/train: 2.43599009513855/15\n",
+      "LOSS train 2.43599009513855 / valid 2.4329328536987305\n",
+      "EPOCH 2\n",
+      "  batch 5 loss: 2.444344425201416\n",
+      "  loss/train: 2.444344425201416/23\n",
+      "  batch 10 loss: 2.4506516456604004\n",
+      "  loss/train: 2.4506516456604004/28\n",
+      "  batch 15 loss: 2.4280431270599365\n",
+      "  loss/train: 2.4280431270599365/33\n",
+      "LOSS train 2.4280431270599365 / valid 2.4371285438537598\n",
+      "EPOCH 3\n",
+      "  batch 5 loss: 2.444630241394043\n",
+      "  loss/train: 2.444630241394043/41\n",
+      "  batch 10 loss: 2.4294994354248045\n",
+      "  loss/train: 2.4294994354248045/46\n",
+      "  batch 15 loss: 2.4486732959747313\n",
+      "  loss/train: 2.4486732959747313/51\n",
+      "LOSS train 2.4486732959747313 / valid 2.437175989151001\n",
+      "EPOCH 4\n",
+      "  batch 5 loss: 2.449464464187622\n",
+      "  loss/train: 2.449464464187622/59\n",
+      "  batch 10 loss: 2.435447835922241\n",
+      "  loss/train: 2.435447835922241/64\n",
+      "  batch 15 loss: 2.4380492210388183\n",
+      "  loss/train: 2.4380492210388183/69\n",
+      "LOSS train 2.4380492210388183 / valid 2.437028408050537\n",
+      "EPOCH 5\n",
+      "  batch 5 loss: 2.4409915447235107\n",
+      "  loss/train: 2.4409915447235107/77\n",
+      "  batch 10 loss: 2.443865346908569\n",
+      "  loss/train: 2.443865346908569/82\n",
+      "  batch 15 loss: 2.4377002716064453\n",
+      "  loss/train: 2.4377002716064453/87\n",
+      "LOSS train 2.4377002716064453 / valid 2.4371585845947266\n"
+     ]
+    }
+   ],
+   "source": [
+    "# train the model\n",
+    "print(\"**** TRAINING NN on supplied data ****\")\n",
+    "\n",
+    "epoch_no = 0\n",
+    "num_epochs = 5\n",
+    "best_vloss = 1000000\n",
+    "\n",
+    "for epoch in range(num_epochs):\n",
+    "    print(\"EPOCH\", epoch + 1)\n",
+    "\n",
+    "    model.train(True)\n",
+    "    avg_loss = train_epoch(epoch_no)\n",
+    "\n",
+    "    running_vloss = 0.0\n",
+    "    model.eval()\n",
+    "    with torch.no_grad():\n",
+    "        for idx, vdata in enumerate(validate_loader):\n",
+    "            vfeatures, vlabels = vdata\n",
+    "            voutputs = model(vfeatures)\n",
+    "            vloss = loss_fn(voutputs, vlabels)\n",
+    "            running_vloss += vloss\n",
+    "\n",
+    "    avg_vloss = running_vloss / (idx + 1)\n",
+    "    print(\"LOSS train {} / valid {}\".format(avg_loss, avg_vloss))\n",
+    "\n",
+    "    if avg_vloss < best_vloss:\n",
+    "        best_vloss = avg_vloss\n",
+    "\n",
+    "    epoch_no += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 259,
+   "id": "4ef4f484-5664-45dd-a5aa-2909276cedc8",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "SeqNN(\n",
+       "  (linear_relu_stack): Sequential(\n",
+       "    (0): Linear(in_features=11, out_features=512, bias=True)\n",
+       "    (1): ReLU()\n",
+       "    (2): Linear(in_features=512, out_features=10, bias=True)\n",
+       "    (3): Softmax(dim=None)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 259,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1ff5d2c3-e979-434f-a6ac-355c68832952",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pytorch-26",
+   "language": "python",
+   "name": "pytorch-26"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
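
The training log in the notebook stays between roughly 2.43 and 2.45 in every epoch. One possible explanation, offered as an assumption rather than a confirmed diagnosis, is that `torch.nn.CrossEntropyLoss` applies log-softmax to its input itself, so feeding it the output of the `nn.Softmax()` layer confines the loss to a narrow band. The sketch below reproduces that arithmetic in plain Python, without PyTorch; `log_softmax` and `cross_entropy` are hypothetical helper names, and the score vectors are made up for illustration.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    top = max(scores)
    log_sum = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_sum for s in scores]


def cross_entropy(scores, target):
    """Cross-entropy on raw scores: minus the log-softmax of the target class."""
    return -log_softmax(scores)[target]


num_classes = 10

# Case 1: raw logits go straight into the loss. A confident, correct
# prediction drives the loss towards zero.
confident_logits = [0.0] * num_classes
confident_logits[5] = 20.0
loss_from_logits = cross_entropy(confident_logits, 5)

# Case 2: the network already applied softmax, so the loss receives
# probabilities in [0, 1] and applies softmax a second time.
one_hot = [0.0] * num_classes
one_hot[5] = 1.0
best_loss = cross_entropy(one_hot, 5)    # all probability on the right class
worst_loss = cross_entropy(one_hot, 6)   # all probability on a wrong class
uniform_loss = cross_entropy([1.0 / num_classes] * num_classes, 5)

print(f"raw logits, confident and correct: {loss_from_logits:.6f}")
print(f"double softmax, best case:         {best_loss:.4f}")
print(f"double softmax, uniform guess:     {uniform_loss:.4f}")
print(f"double softmax, worst case:        {worst_loss:.4f}")
```

With ten classes the doubly-normalised loss is confined to about 1.46 to 2.46, and the logged values near 2.44 sit close to the upper end. That would be consistent with a saturated softmax that mostly points at a wrong class. The usual arrangement is to drop the final `nn.Softmax()` and pass the raw scores to the loss.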

File diff too large to display
+ 201 - 0
code/wine-sklearn.ipynb


+ 71 - 0
code/wine-sklearn.py

@@ -0,0 +1,71 @@
+#!/bin/false
+# ^^^ this just means don't allow this to be executed as a stand-alone script
+
+# the basic imports
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+# but also reporting on the model
+from sklearn.metrics import classification_report, confusion_matrix
+
+# load data, extract just the features, and just the labels
+wine_data = pd.read_csv("./WineQT.csv", delimiter=",")
+wine_features = wine_data.drop("quality", axis=1).drop("Id", axis=1)
+wine_labels = np.ravel(wine_data['quality'])
+
+# split the dataset into train and test subsets
+# note, while it may be tempting to get creative with variable names, such as
+# features_train, features_test, labels_train, labels_test...
+# it's WAY TOO MUCH typing, and most examples use x for features (as in, input
+# data) and y for labels (as in, result)
+from sklearn.model_selection import train_test_split
+
+x_train, x_test, y_train, y_test = train_test_split(wine_features, wine_labels, test_size=0.5, random_state=50)
+
+# standardise the data (rescale each feature to zero mean and unit variance)
+from sklearn.preprocessing import StandardScaler
+
+scaler = StandardScaler().fit(x_train)
+x_train = scaler.transform(x_train)
+x_test = scaler.transform(x_test)
+
+# train the SVC model
+print("**** TESTING C-Support Vector Classification ****")
+
+from sklearn.svm import SVC
+
+svc_model = SVC()
+svc_model.fit(x_train, y_train)
+
+# now test the fitness with the test subset
+svc_y_predict = svc_model.predict(x_test)
+
+# visualise it
+svc_cm = np.array(confusion_matrix(y_test, svc_y_predict, labels=[0,1,2,3,4,5,6,7,8,9,10]))
+svc_conf_matrix = pd.DataFrame(svc_cm)
+print(svc_conf_matrix)
+
+# visualise it in a nice picture
+sns.heatmap(svc_conf_matrix, annot=True, fmt='g')
+plt.show()
+
+# # train the NuSVC model
+# print("**** TESTING Nu-Support Vector Classification ****")
+
+# from sklearn.svm import NuSVC
+
+# nusvc_model = NuSVC(nu=0.2)
+# nusvc_model.fit(x_train, y_train)
+
+# # now test the fitness with the test subset
+# nusvc_y_predict = nusvc_model.predict(x_test)
+
+# # visualise it
+# nu_cm = np.array(confusion_matrix(y_test, nusvc_y_predict, labels=[0,1,2,3,4,5,6,7,8,9,10]))
+# nu_conf_matrix = pd.DataFrame(nu_cm)
+# print(nu_conf_matrix)
+
+# # visualise it in a nice picture
+# sns.heatmap(nu_conf_matrix, annot=True, fmt='g')
+# plt.show()
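
The scaling step in the script fits `StandardScaler` on the training split only and then applies the same transformation to both splits. As a rough sketch of what that means for a single column, here is the same idea in plain Python. `fit_standardiser` and `standardise` are hypothetical helpers, not scikit-learn functions, and the numbers are a handful of values from the alcohol column shown in the notebook output, used purely for illustration.

```python
import statistics


def fit_standardiser(column):
    """Learn the mean and population standard deviation of one column."""
    mean = statistics.fmean(column)
    std = statistics.pstdev(column)
    # Assumption: treat a constant column like a scale of 1 to avoid dividing by zero.
    if std == 0:
        std = 1.0
    return mean, std


def standardise(column, mean, std):
    """Rescale values with statistics learned elsewhere (here: the training split)."""
    return [(value - mean) / std for value in column]


train_alcohol = [9.8, 9.8, 9.8, 9.4, 9.5]
test_alcohol = [11.9, 11.6, 10.9]

# Fit on the training values only, then reuse the same mean and std for the test values.
mean, std = fit_standardiser(train_alcohol)
scaled_train = standardise(train_alcohol, mean, std)
scaled_test = standardise(test_alcohol, mean, std)

print("train:", [round(v, 3) for v in scaled_train])
print("test: ", [round(v, 3) for v in scaled_test])
```

The scaled training values end up with zero mean and unit variance, while the test values do not have to: they are measured against the training statistics, which keeps information about the test split out of the fitting step.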

+ 0 - 0
env-pytorch-26.yml → envs/env-pytorch-26.yml


+ 0 - 0
env-sklearn-16.yml → envs/env-sklearn-16.yml


+ 0 - 0
env-tf-216.yml → envs/env-tf-216.yml


Some files are not shown because too many files changed in this diff