{ "cells": [ { "cell_type": "markdown", "id": "bf68dfaf-37f5-463d-b72e-2120d27f0b0d", "metadata": {}, "source": [ "Imports for this Notebook:" ] }, { "cell_type": "code", "execution_count": 234, "id": "63528a79-842d-44ec-9109-29df1621cef8", "metadata": {}, "outputs": [], "source": [ "# the basic imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "# DataLoader wraps a dataset so it can be iterated over in batches\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "markdown", "id": "ee773350-56e9-4601-a33e-b97c46a2cafc", "metadata": {}, "source": [ "Load and split the wine quality dataset with Pandas." ] }, { "cell_type": "code", "execution_count": 243, "id": "d106a41b-e75c-4ca0-96e1-0e5b117992bf", "metadata": {}, "outputs": [], "source": [ "# load data (only get rid of the Id column)\n", "wine_data = pd.read_csv(\"./WineQT.csv\", delimiter=\",\").drop(\"Id\", axis=1)\n", "\n", "# split the dataset into model (70%) and test (30%) subsets;\n", "# sample() keeps the original index labels, so dropping the sampled labels\n", "# from the source frame leaves exactly the rows that were not sampled\n", "wine_model = wine_data.sample(frac=0.7, random_state=123)\n", "wine_test = wine_data.drop(wine_model.index)\n", "\n", "# further split the model subset into train (70%) and validation (30%)\n", "wine_train = wine_model.sample(frac=0.7, random_state=123)\n", "wine_validate = wine_model.drop(wine_train.index)" ] }, { "cell_type": "code", "execution_count": 244, "id": "29301f87-1919-48ae-8c9e-1a55f4fb1495", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: 1143\n", "train: 560 validate: 240 test: 343 == 1143\n" ] } ], "source": [ "# verify the produced data size matches the source\n", "# (add the row counts; adding the frames themselves would do element-wise arithmetic)\n", "print(\"source:\", len(wine_data))\n", "print(\"train:\", len(wine_train), \"validate:\", len(wine_validate), \"test:\", len(wine_test), \"==\", len(wine_train) + len(wine_validate) + len(wine_test))" ] }, { "cell_type": "code", "execution_count": 245, "id": "3087bcdc-5c50-4b79-a2d8-f442a8d81825", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | fixed acidity | \n", "volatile acidity | \n", "citric acid | \n", "residual sugar | \n", "chlorides | \n", "free sulfur dioxide | \n", "total sulfur dioxide | \n", "density | \n", "pH | \n", "sulphates | \n", "alcohol | \n", "quality | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|
711 | \n", "7.5 | \n", "0.71 | \n", "0.00 | \n", "1.6 | \n", "0.092 | \n", "22.0 | \n", "31.0 | \n", "0.99635 | \n", "3.38 | \n", "0.58 | \n", "10.0 | \n", "6 | \n", "
936 | \n", "9.1 | \n", "0.34 | \n", "0.42 | \n", "1.8 | \n", "0.058 | \n", "9.0 | \n", "18.0 | \n", "0.99392 | \n", "3.18 | \n", "0.55 | \n", "11.4 | \n", "5 | \n", "
683 | \n", "10.4 | \n", "0.26 | \n", "0.48 | \n", "1.9 | \n", "0.066 | \n", "6.0 | \n", "10.0 | \n", "0.99724 | \n", "3.33 | \n", "0.87 | \n", "10.9 | \n", "6 | \n", "
904 | \n", "8.5 | \n", "0.40 | \n", "0.40 | \n", "6.3 | \n", "0.050 | \n", "3.0 | \n", "10.0 | \n", "0.99566 | \n", "3.28 | \n", "0.56 | \n", "12.0 | \n", "4 | \n", "
198 | \n", "8.9 | \n", "0.40 | \n", "0.32 | \n", "5.6 | \n", "0.087 | \n", "10.0 | \n", "47.0 | \n", "0.99910 | \n", "3.38 | \n", "0.77 | \n", "10.5 | \n", "7 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1076 | \n", "7.5 | \n", "0.38 | \n", "0.57 | \n", "2.3 | \n", "0.106 | \n", "5.0 | \n", "12.0 | \n", "0.99605 | \n", "3.36 | \n", "0.55 | \n", "11.4 | \n", "6 | \n", "
709 | \n", "8.9 | \n", "0.32 | \n", "0.31 | \n", "2.0 | \n", "0.088 | \n", "12.0 | \n", "19.0 | \n", "0.99570 | \n", "3.17 | \n", "0.55 | \n", "10.4 | \n", "6 | \n", "
486 | \n", "8.1 | \n", "0.78 | \n", "0.23 | \n", "2.6 | \n", "0.059 | \n", "5.0 | \n", "15.0 | \n", "0.99700 | \n", "3.37 | \n", "0.56 | \n", "11.3 | \n", "5 | \n", "
182 | \n", "7.7 | \n", "0.41 | \n", "0.76 | \n", "1.8 | \n", "0.611 | \n", "8.0 | \n", "45.0 | \n", "0.99680 | \n", "3.06 | \n", "1.26 | \n", "9.4 | \n", "5 | \n", "
206 | \n", "8.7 | \n", "0.52 | \n", "0.09 | \n", "2.5 | \n", "0.091 | \n", "20.0 | \n", "49.0 | \n", "0.99760 | \n", "3.34 | \n", "0.86 | \n", "10.6 | \n", "7 | \n", "
560 rows × 12 columns
\n", "\n", " | fixed acidity | \n", "volatile acidity | \n", "citric acid | \n", "residual sugar | \n", "chlorides | \n", "free sulfur dioxide | \n", "total sulfur dioxide | \n", "density | \n", "pH | \n", "sulphates | \n", "alcohol | \n", "quality | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|
309 | \n", "7.0 | \n", "0.620 | \n", "0.18 | \n", "1.5 | \n", "0.062 | \n", "7.0 | \n", "50.0 | \n", "0.99510 | \n", "3.08 | \n", "0.60 | \n", "9.3 | \n", "5 | \n", "
528 | \n", "8.3 | \n", "0.760 | \n", "0.29 | \n", "4.2 | \n", "0.075 | \n", "12.0 | \n", "16.0 | \n", "0.99650 | \n", "3.45 | \n", "0.68 | \n", "11.5 | \n", "6 | \n", "
150 | \n", "8.2 | \n", "0.570 | \n", "0.26 | \n", "2.2 | \n", "0.060 | \n", "28.0 | \n", "65.0 | \n", "0.99590 | \n", "3.30 | \n", "0.43 | \n", "10.1 | \n", "5 | \n", "
278 | \n", "6.6 | \n", "0.735 | \n", "0.02 | \n", "7.9 | \n", "0.122 | \n", "68.0 | \n", "124.0 | \n", "0.99940 | \n", "3.47 | \n", "0.53 | \n", "9.9 | \n", "5 | \n", "
490 | \n", "8.6 | \n", "0.490 | \n", "0.51 | \n", "2.0 | \n", "0.422 | \n", "16.0 | \n", "62.0 | \n", "0.99790 | \n", "3.03 | \n", "1.17 | \n", "9.0 | \n", "5 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
896 | \n", "10.4 | \n", "0.430 | \n", "0.50 | \n", "2.3 | \n", "0.068 | \n", "13.0 | \n", "19.0 | \n", "0.99600 | \n", "3.10 | \n", "0.87 | \n", "11.4 | \n", "6 | \n", "
922 | \n", "7.6 | \n", "1.580 | \n", "0.00 | \n", "2.1 | \n", "0.137 | \n", "5.0 | \n", "9.0 | \n", "0.99476 | \n", "3.50 | \n", "0.40 | \n", "10.9 | \n", "3 | \n", "
219 | \n", "8.4 | \n", "0.650 | \n", "0.60 | \n", "2.1 | \n", "0.112 | \n", "12.0 | \n", "90.0 | \n", "0.99730 | \n", "3.20 | \n", "0.52 | \n", "9.2 | \n", "5 | \n", "
970 | \n", "7.3 | \n", "0.740 | \n", "0.08 | \n", "1.7 | \n", "0.094 | \n", "10.0 | \n", "45.0 | \n", "0.99576 | \n", "3.24 | \n", "0.50 | \n", "9.8 | \n", "5 | \n", "
288 | \n", "8.8 | \n", "0.520 | \n", "0.34 | \n", "2.7 | \n", "0.087 | \n", "24.0 | \n", "122.0 | \n", "0.99820 | \n", "3.26 | \n", "0.61 | \n", "9.5 | \n", "5 | \n", "
240 rows × 12 columns
\n", "\n", " | fixed acidity | \n", "volatile acidity | \n", "citric acid | \n", "residual sugar | \n", "chlorides | \n", "free sulfur dioxide | \n", "total sulfur dioxide | \n", "density | \n", "pH | \n", "sulphates | \n", "alcohol | \n", "quality | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | \n", "7.8 | \n", "0.88 | \n", "0.00 | \n", "2.6 | \n", "0.098 | \n", "25.0 | \n", "67.0 | \n", "0.99680 | \n", "3.20 | \n", "0.68 | \n", "9.8 | \n", "5 | \n", "
2 | \n", "7.8 | \n", "0.76 | \n", "0.04 | \n", "2.3 | \n", "0.092 | \n", "15.0 | \n", "54.0 | \n", "0.99700 | \n", "3.26 | \n", "0.65 | \n", "9.8 | \n", "5 | \n", "
3 | \n", "11.2 | \n", "0.28 | \n", "0.56 | \n", "1.9 | \n", "0.075 | \n", "17.0 | \n", "60.0 | \n", "0.99800 | \n", "3.16 | \n", "0.58 | \n", "9.8 | \n", "6 | \n", "
6 | \n", "7.9 | \n", "0.60 | \n", "0.06 | \n", "1.6 | \n", "0.069 | \n", "15.0 | \n", "59.0 | \n", "0.99640 | \n", "3.30 | \n", "0.46 | \n", "9.4 | \n", "5 | \n", "
8 | \n", "7.8 | \n", "0.58 | \n", "0.02 | \n", "2.0 | \n", "0.073 | \n", "9.0 | \n", "18.0 | \n", "0.99680 | \n", "3.36 | \n", "0.57 | \n", "9.5 | \n", "7 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1128 | \n", "6.2 | \n", "0.70 | \n", "0.15 | \n", "5.1 | \n", "0.076 | \n", "13.0 | \n", "27.0 | \n", "0.99622 | \n", "3.54 | \n", "0.60 | \n", "11.9 | \n", "6 | \n", "
1133 | \n", "6.7 | \n", "0.32 | \n", "0.44 | \n", "2.4 | \n", "0.061 | \n", "24.0 | \n", "34.0 | \n", "0.99484 | \n", "3.29 | \n", "0.80 | \n", "11.6 | \n", "7 | \n", "
1135 | \n", "5.8 | \n", "0.61 | \n", "0.11 | \n", "1.8 | \n", "0.066 | \n", "18.0 | \n", "28.0 | \n", "0.99483 | \n", "3.55 | \n", "0.66 | \n", "10.9 | \n", "6 | \n", "
1137 | \n", "5.4 | \n", "0.74 | \n", "0.09 | \n", "1.7 | \n", "0.089 | \n", "16.0 | \n", "26.0 | \n", "0.99402 | \n", "3.67 | \n", "0.56 | \n", "11.6 | \n", "6 | \n", "
1140 | \n", "6.2 | \n", "0.60 | \n", "0.08 | \n", "2.0 | \n", "0.090 | \n", "32.0 | \n", "44.0 | \n", "0.99490 | \n", "3.45 | \n", "0.58 | \n", "10.5 | \n", "5 | \n", "
343 rows × 12 columns
\n", "