{ "cells": [
{ "cell_type": "markdown", "id": "bf68dfaf-37f5-463d-b72e-2120d27f0b0d", "metadata": {}, "source": [ "Imports for this notebook:" ] },
{ "cell_type": "code", "execution_count": 234, "id": "63528a79-842d-44ec-9109-29df1621cef8", "metadata": {}, "outputs": [], "source": [ "# the basic imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "# DataLoader batches (and optionally shuffles) a Dataset for training/evaluation\n", "from torch.utils.data import DataLoader" ] },
{ "cell_type": "markdown", "id": "ee773350-56e9-4601-a33e-b97c46a2cafc", "metadata": {}, "source": [ "Load the wine quality dataset with pandas and split it into train (49%), validation (21%) and test (30%) subsets." ] },
{ "cell_type": "code", "execution_count": 243, "id": "d106a41b-e75c-4ca0-96e1-0e5b117992bf", "metadata": {}, "outputs": [], "source": [ "# load data (only get rid of the Id column)\n", "wine_data = pd.read_csv(\"./WineQT.csv\", delimiter=\",\").drop(\"Id\", axis=1)\n", "\n", "# split the dataset into model (70%) and test (30%) subsets\n", "wine_model = wine_data.sample(frac=0.7, random_state=123)\n", "# the test set is every row NOT sampled into wine_model\n", "# (dropping wine_train.index instead would fail on a fresh kernel and leak validation rows into test)\n", "wine_test = wine_data.drop(wine_model.index)\n", "\n", "# further split the model set into train (70%) and validation (30%)\n", "wine_train = wine_model.sample(frac=0.7, random_state=123)\n", "wine_validate = wine_model.drop(wine_train.index)" ] },
{ "cell_type": "code", "execution_count": 244, "id": "29301f87-1919-48ae-8c9e-1a55f4fb1495", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: 1143\n", "train: 560 validate: 240 test: 343 == 1143\n" ] } ], "source": [ "# verify the subsets are disjoint and together cover the source exactly\n", "# (summing lengths: adding the DataFrames themselves would align on index and produce NaNs)\n", "n_total = len(wine_train) + len(wine_validate) + len(wine_test)\n", "print(\"source:\", len(wine_data))\n", "print(\"train:\", len(wine_train), \"validate:\", len(wine_validate), \"test:\", len(wine_test), \"==\", n_total)\n", "assert n_total == len(wine_data)\n", "assert wine_train.index.intersection(wine_validate.index).empty\n", "assert wine_model.index.intersection(wine_test.index).empty" ] },
{ "cell_type": "code", "execution_count": 245, "id": "3087bcdc-5c50-4b79-a2d8-f442a8d81825", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
7117.50.710.001.60.09222.031.00.996353.380.5810.06
9369.10.340.421.80.0589.018.00.993923.180.5511.45
68310.40.260.481.90.0666.010.00.997243.330.8710.96
9048.50.400.406.30.0503.010.00.995663.280.5612.04
1988.90.400.325.60.08710.047.00.999103.380.7710.57
.......................................
10767.50.380.572.30.1065.012.00.996053.360.5511.46
7098.90.320.312.00.08812.019.00.995703.170.5510.46
4868.10.780.232.60.0595.015.00.997003.370.5611.35
1827.70.410.761.80.6118.045.00.996803.061.269.45
2068.70.520.092.50.09120.049.00.997603.340.8610.67
\n", "

560 rows × 12 columns

\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "711 7.5 0.71 0.00 1.6 0.092 \n", "936 9.1 0.34 0.42 1.8 0.058 \n", "683 10.4 0.26 0.48 1.9 0.066 \n", "904 8.5 0.40 0.40 6.3 0.050 \n", "198 8.9 0.40 0.32 5.6 0.087 \n", "... ... ... ... ... ... \n", "1076 7.5 0.38 0.57 2.3 0.106 \n", "709 8.9 0.32 0.31 2.0 0.088 \n", "486 8.1 0.78 0.23 2.6 0.059 \n", "182 7.7 0.41 0.76 1.8 0.611 \n", "206 8.7 0.52 0.09 2.5 0.091 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "711 22.0 31.0 0.99635 3.38 0.58 \n", "936 9.0 18.0 0.99392 3.18 0.55 \n", "683 6.0 10.0 0.99724 3.33 0.87 \n", "904 3.0 10.0 0.99566 3.28 0.56 \n", "198 10.0 47.0 0.99910 3.38 0.77 \n", "... ... ... ... ... ... \n", "1076 5.0 12.0 0.99605 3.36 0.55 \n", "709 12.0 19.0 0.99570 3.17 0.55 \n", "486 5.0 15.0 0.99700 3.37 0.56 \n", "182 8.0 45.0 0.99680 3.06 1.26 \n", "206 20.0 49.0 0.99760 3.34 0.86 \n", "\n", " alcohol quality \n", "711 10.0 6 \n", "936 11.4 5 \n", "683 10.9 6 \n", "904 12.0 4 \n", "198 10.5 7 \n", "... ... ... \n", "1076 11.4 6 \n", "709 10.4 6 \n", "486 11.3 5 \n", "182 9.4 5 \n", "206 10.6 7 \n", "\n", "[560 rows x 12 columns]" ] }, "execution_count": 245, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wine_train" ] }, { "cell_type": "code", "execution_count": 246, "id": "241f2cdb-a7fb-4165-be11-95ca6510082b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
3097.00.6200.181.50.0627.050.00.995103.080.609.35
5288.30.7600.294.20.07512.016.00.996503.450.6811.56
1508.20.5700.262.20.06028.065.00.995903.300.4310.15
2786.60.7350.027.90.12268.0124.00.999403.470.539.95
4908.60.4900.512.00.42216.062.00.997903.031.179.05
.......................................
89610.40.4300.502.30.06813.019.00.996003.100.8711.46
9227.61.5800.002.10.1375.09.00.994763.500.4010.93
2198.40.6500.602.10.11212.090.00.997303.200.529.25
9707.30.7400.081.70.09410.045.00.995763.240.509.85
2888.80.5200.342.70.08724.0122.00.998203.260.619.55
\n", "

240 rows × 12 columns

\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "309 7.0 0.620 0.18 1.5 0.062 \n", "528 8.3 0.760 0.29 4.2 0.075 \n", "150 8.2 0.570 0.26 2.2 0.060 \n", "278 6.6 0.735 0.02 7.9 0.122 \n", "490 8.6 0.490 0.51 2.0 0.422 \n", ".. ... ... ... ... ... \n", "896 10.4 0.430 0.50 2.3 0.068 \n", "922 7.6 1.580 0.00 2.1 0.137 \n", "219 8.4 0.650 0.60 2.1 0.112 \n", "970 7.3 0.740 0.08 1.7 0.094 \n", "288 8.8 0.520 0.34 2.7 0.087 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "309 7.0 50.0 0.99510 3.08 0.60 \n", "528 12.0 16.0 0.99650 3.45 0.68 \n", "150 28.0 65.0 0.99590 3.30 0.43 \n", "278 68.0 124.0 0.99940 3.47 0.53 \n", "490 16.0 62.0 0.99790 3.03 1.17 \n", ".. ... ... ... ... ... \n", "896 13.0 19.0 0.99600 3.10 0.87 \n", "922 5.0 9.0 0.99476 3.50 0.40 \n", "219 12.0 90.0 0.99730 3.20 0.52 \n", "970 10.0 45.0 0.99576 3.24 0.50 \n", "288 24.0 122.0 0.99820 3.26 0.61 \n", "\n", " alcohol quality \n", "309 9.3 5 \n", "528 11.5 6 \n", "150 10.1 5 \n", "278 9.9 5 \n", "490 9.0 5 \n", ".. ... ... \n", "896 11.4 6 \n", "922 10.9 3 \n", "219 9.2 5 \n", "970 9.8 5 \n", "288 9.5 5 \n", "\n", "[240 rows x 12 columns]" ] }, "execution_count": 246, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wine_validate" ] }, { "cell_type": "code", "execution_count": 247, "id": "b1e11d58-2415-4a17-9883-410e1bd15c21", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
17.80.880.002.60.09825.067.00.996803.200.689.85
27.80.760.042.30.09215.054.00.997003.260.659.85
311.20.280.561.90.07517.060.00.998003.160.589.86
67.90.600.061.60.06915.059.00.996403.300.469.45
87.80.580.022.00.0739.018.00.996803.360.579.57
.......................................
11286.20.700.155.10.07613.027.00.996223.540.6011.96
11336.70.320.442.40.06124.034.00.994843.290.8011.67
11355.80.610.111.80.06618.028.00.994833.550.6610.96
11375.40.740.091.70.08916.026.00.994023.670.5611.66
11406.20.600.082.00.09032.044.00.994903.450.5810.55
\n", "

343 rows × 12 columns

\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "1 7.8 0.88 0.00 2.6 0.098 \n", "2 7.8 0.76 0.04 2.3 0.092 \n", "3 11.2 0.28 0.56 1.9 0.075 \n", "6 7.9 0.60 0.06 1.6 0.069 \n", "8 7.8 0.58 0.02 2.0 0.073 \n", "... ... ... ... ... ... \n", "1128 6.2 0.70 0.15 5.1 0.076 \n", "1133 6.7 0.32 0.44 2.4 0.061 \n", "1135 5.8 0.61 0.11 1.8 0.066 \n", "1137 5.4 0.74 0.09 1.7 0.089 \n", "1140 6.2 0.60 0.08 2.0 0.090 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "1 25.0 67.0 0.99680 3.20 0.68 \n", "2 15.0 54.0 0.99700 3.26 0.65 \n", "3 17.0 60.0 0.99800 3.16 0.58 \n", "6 15.0 59.0 0.99640 3.30 0.46 \n", "8 9.0 18.0 0.99680 3.36 0.57 \n", "... ... ... ... ... ... \n", "1128 13.0 27.0 0.99622 3.54 0.60 \n", "1133 24.0 34.0 0.99484 3.29 0.80 \n", "1135 18.0 28.0 0.99483 3.55 0.66 \n", "1137 16.0 26.0 0.99402 3.67 0.56 \n", "1140 32.0 44.0 0.99490 3.45 0.58 \n", "\n", " alcohol quality \n", "1 9.8 5 \n", "2 9.8 5 \n", "3 9.8 6 \n", "6 9.4 5 \n", "8 9.5 7 \n", "... ... ... 
\n", "1128 11.9 6 \n", "1133 11.6 7 \n", "1135 10.9 6 \n", "1137 11.6 6 \n", "1140 10.5 5 \n", "\n", "[343 rows x 12 columns]" ] }, "execution_count": 247, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wine_test" ] },
{ "cell_type": "code", "execution_count": 248, "id": "88bc3d8e-10c7-4887-83c8-21e0274349a2", "metadata": {}, "outputs": [], "source": [ "# isolate results from features\n", "# now this is the hassle most of the time - convert whatever the input is to whatever it's expected to be\n", "# in our case, pandas data frames to pytorch tensors\n", "def split_xy(frame):\n", "    \"\"\"Return (features, target) tensors for a wine frame.\n", "\n", "    features: every column except 'quality', as float32.\n", "    target: the 'quality' column as int64, used directly as class indices by CrossEntropyLoss.\n", "    \"\"\"\n", "    features = torch.tensor(frame.drop('quality', axis=1).values.astype(np.float32))\n", "    target = torch.tensor(frame['quality'].values.astype(np.int64))\n", "    return features, target\n", "\n", "\n", "train_features, train_target = split_xy(wine_train)\n", "validate_features, validate_target = split_xy(wine_validate)\n", "test_features, test_target = split_xy(wine_test)\n", "\n", "train_data = torch.utils.data.TensorDataset(train_features, train_target)\n", "validate_data = torch.utils.data.TensorDataset(validate_features, validate_target)\n", "test_data = torch.utils.data.TensorDataset(test_features, test_target)" ] },
{ "cell_type": "code", "execution_count": 249, "id": "21a01ec6-44d9-4f50-b093-e80816345dd0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([ 7.5000, 0.7100, 0.0000, 1.6000, 0.0920, 22.0000, 31.0000, 0.9963,\n", " 3.3800, 0.5800, 10.0000]),\n", " tensor(6))" ] }, "execution_count": 249, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data[0]" ] },
{ "cell_type": "code", "execution_count": 250, "id": "84198c1f-5a45-45ad-94f8-9e3de5a4efd5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([ 7.0000, 0.6200, 0.1800, 1.5000, 0.0620, 7.0000, 50.0000, 0.9951,\n", " 3.0800, 0.6000, 9.3000]),\n", " tensor(5))" ] }, "execution_count": 250, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate_data[0]" ] },
{ "cell_type": "code", "execution_count": 251, "id": "c0685118-8baf-48d1-a554-a8b94a85d39a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([ 7.8000, 0.8800, 0.0000, 2.6000, 0.0980, 25.0000, 67.0000, 0.9968,\n", " 3.2000, 0.6800, 9.8000]),\n", " tensor(5))" ] }, "execution_count": 251, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data[0]" ] },
{ "cell_type": "code", "execution_count": 252, "id": "c59d990a-3f85-4981-a288-8877ef1fa34c", "metadata": {}, "outputs": [], "source": [ "# create data loaders\n", "# only the training set needs shuffling; the order of validation/test batches does not change their mean loss\n", "# pin_memory only benefits CUDA host-to-device copies, so enable it only when CUDA is present\n", "loader_opts = dict(batch_size=32, num_workers=4, pin_memory=torch.cuda.is_available())\n", "train_loader = DataLoader(train_data, shuffle=True, **loader_opts)\n", "validate_loader = DataLoader(validate_data, shuffle=False, **loader_opts)\n", "test_loader = DataLoader(test_data, shuffle=False, **loader_opts)" ] },
{ "cell_type": "code", "execution_count": 224, "id": "e3aaa169-e086-4a19-865d-557c0c2aa210", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accelerated by mps\n" ] } ], "source": [ "# check what device to use\n", "from torch.accelerator import is_available, current_accelerator\n", "if is_available():\n", "    device = current_accelerator().type\n", "    print(f\"accelerated by {device}\")\n", "else:\n", "    device = \"cpu\"\n", "    print(\"no accelerator\")" ] },
{ "cell_type": "code", "execution_count": 229, "id": "a1ee2d09-caed-4d39-ab51-cc27504753ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "**** DEFINING Linear/ReLU layer stack ****\n" ] } ], "source": [ "# define the model\n", "print(\"**** DEFINING Linear/ReLU layer stack ****\")\n", "\n", "from torch import nn\n", "\n", "class SeqNN(nn.Module):\n", "    \"\"\"Two-layer MLP mapping the 11 wine features to 10 class scores (raw logits).\n", "\n", "    The 10 outputs index quality classes 0-9; the integer 'quality' label is used directly as the class index.\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        # no Softmax at the end: nn.CrossEntropyLoss applies log-softmax itself and expects raw logits\n", "        self.linear_relu_stack = nn.Sequential(\n", "            nn.Linear(11, 512),\n", "            nn.ReLU(),\n", "            nn.Linear(512, 10),\n", "        )\n", "\n", "    def forward(self, x):\n", "        logits = self.linear_relu_stack(x)\n", "        return logits\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "9e95ca59-e228-4b84-b3ec-f12c6e6dcb02", "metadata": {}, "outputs": [], "source": [ "# instantiate the model on the selected device and show what it looks like\n", "torch.manual_seed(123)  # reproducible weight initialisation (and loader shuffling)\n", "model = SeqNN().to(device)\n", "print(model)" ] },
{ "cell_type": "code", "execution_count": 227, "id": "6df0edaa-bf1a-4ff5-ae41-2f1c03811c7c", "metadata": {}, "outputs": [], "source": [ "# define the loss function and optimizer\n", "# NOTE: re-run this cell whenever `model` is re-created; the optimizer holds references to the\n", "# parameters of whichever model existed when it was built\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.90)" ] },
{ "cell_type": "code", "execution_count": 257, "id": "102c4fc5-0798-40c5-bb32-9c87fc7707d5", "metadata": {}, "outputs": [], "source": [ "def train_epoch(ep_num):\n", "    \"\"\"Run one pass over train_loader, updating the global `model` in place.\n", "\n", "    Prints the mean loss of every `report_every` consecutive batches, tagged with a\n", "    global batch counter derived from `ep_num`. Returns the mean loss of the last\n", "    complete reporting window (0.0 if the epoch had fewer than `report_every` batches).\n", "    \"\"\"\n", "    running_loss = 0.\n", "    last_loss = 0.\n", "\n", "    report_every = 5\n", "\n", "    for idx, (features, labels) in enumerate(train_loader):\n", "        # Tensor.to() returns a new tensor rather than moving in place, so rebind both\n", "        features = features.to(device)\n", "        labels = labels.to(device)\n", "\n", "        # zero the optimizer gradients\n", "        optimizer.zero_grad()\n", "\n", "        # forward pass, loss against the expected label, backward pass\n", "        outputs = model(features)\n", "        loss = loss_fn(outputs, labels)\n", "        loss.backward()\n", "\n", "        # apply the gradient update\n", "        optimizer.step()\n", "\n", "        # accumulate the loss for the periodic report\n", "        running_loss += loss.item()\n", "\n", "        # report the mean loss once every `report_every` batches\n", "        if idx % report_every == (report_every - 1):\n", "            last_loss = running_loss / report_every\n", "            print(\" batch {} loss: {}\".format(idx + 1, last_loss))\n", "            tb_x = ep_num * len(train_loader) + idx + 1\n", "            print(\" loss/train: {}/{}\".format(last_loss, tb_x))\n", "            running_loss = 0.\n", "\n", "    return last_loss" ] },
{ "cell_type": "code", "execution_count": null, "id": "cdf2c8ae-d024-4687-b1df-23804f3d9816", "metadata": {}, "outputs": [], "source": [ "# train the model\n", "print(\"**** TRAINING NN on supplied data ****\")\n", "\n", "num_epochs = 5\n", "best_vloss = float(\"inf\")\n", "\n", "for epoch in range(num_epochs):\n", "    print(\"EPOCH\", epoch + 1)\n", "\n", "    model.train(True)\n", "    avg_loss = train_epoch(epoch)\n", "\n", "    # evaluate on the validation set without tracking gradients\n", "    running_vloss = 0.0\n", "    model.eval()\n", "    with torch.no_grad():\n", "        for vfeatures, vlabels in validate_loader:\n", "            voutputs = model(vfeatures.to(device))\n", "            vloss = loss_fn(voutputs, vlabels.to(device))\n", "            # .item() keeps the accumulator a plain float\n", "            running_vloss += vloss.item()\n", "\n", "    # mean of per-batch mean losses\n", "    avg_vloss = running_vloss / len(validate_loader)\n", "    print(\"LOSS train {} / valid {}\".format(avg_loss, avg_vloss))\n", "\n", "    if avg_vloss < best_vloss:\n", "        best_vloss = avg_vloss" ] },
{ "cell_type": "code", "execution_count": null, "id": "4ef4f484-5664-45dd-a5aa-2909276cedc8", "metadata": {}, "outputs": [], "source": [ "model.eval()" ] },
{ "cell_type": "code", "execution_count": null, "id": "1ff5d2c3-e979-434f-a6ac-355c68832952", "metadata": {}, "outputs": [], "source": [] }
], "metadata": { "kernelspec": { "display_name": "pytorch-26", "language": "python", "name": "pytorch-26" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }