{ "cells": [ { "cell_type": "markdown", "id": "b8f36fba-e6c1-4f65-ae88-35e795d8e89e", "metadata": {}, "source": [ "Load the _Wine Quality Dataset (Combined)_ data." ] }, { "cell_type": "code", "execution_count": 1, "id": "96634688-0fce-478d-a570-edad08bd37bc", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "\n", "# Resolved relative to the kernel's working directory (normally the folder containing this notebook).\n", "DATA_PATH = Path(\"..\") / \"WineQuality.csv\"\n", "\n", "data = pd.read_csv(DATA_PATH)" ] }, { "cell_type": "markdown", "id": "b7b0e979-1712-4422-9baa-4076c58828c9", "metadata": {}, "source": [ "Take a first look at the data (pandas truncates the display to the first and last five rows)." ] }, { "cell_type": "code", "execution_count": 2, "id": "642be8ad-0912-452d-b914-47189e60dd71", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholqualityType
027327.40.1700.291.40.04723.0107.00.993903.520.6510.46White Wine
126075.30.3100.3810.50.03153.0140.00.993213.340.4611.76White Wine
216534.70.1450.291.00.04235.090.00.990803.760.4911.36White Wine
332646.90.2600.294.20.04333.0114.00.990203.160.3112.56White Wine
449316.40.4500.071.10.03010.0131.00.990502.970.2810.85White Wine
.............................................
3248028385.00.2550.222.70.04346.0153.00.992383.750.7611.36White Wine
3248164146.60.3600.5211.30.0468.0110.00.996603.070.469.45White Wine
3248211266.30.2000.241.70.05236.0135.00.993743.800.6610.86White Wine
3248329246.20.2000.335.40.02821.075.00.990123.360.4113.57White Wine
3248454628.10.2800.4615.40.05932.0177.01.000403.270.589.04White Wine
\n", "

32485 rows × 14 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 fixed acidity volatile acidity citric acid \\\n", "0 2732 7.4 0.170 0.29 \n", "1 2607 5.3 0.310 0.38 \n", "2 1653 4.7 0.145 0.29 \n", "3 3264 6.9 0.260 0.29 \n", "4 4931 6.4 0.450 0.07 \n", "... ... ... ... ... \n", "32480 2838 5.0 0.255 0.22 \n", "32481 6414 6.6 0.360 0.52 \n", "32482 1126 6.3 0.200 0.24 \n", "32483 2924 6.2 0.200 0.33 \n", "32484 5462 8.1 0.280 0.46 \n", "\n", " residual sugar chlorides free sulfur dioxide total sulfur dioxide \\\n", "0 1.4 0.047 23.0 107.0 \n", "1 10.5 0.031 53.0 140.0 \n", "2 1.0 0.042 35.0 90.0 \n", "3 4.2 0.043 33.0 114.0 \n", "4 1.1 0.030 10.0 131.0 \n", "... ... ... ... ... \n", "32480 2.7 0.043 46.0 153.0 \n", "32481 11.3 0.046 8.0 110.0 \n", "32482 1.7 0.052 36.0 135.0 \n", "32483 5.4 0.028 21.0 75.0 \n", "32484 15.4 0.059 32.0 177.0 \n", "\n", " density pH sulphates alcohol quality Type \n", "0 0.99390 3.52 0.65 10.4 6 White Wine \n", "1 0.99321 3.34 0.46 11.7 6 White Wine \n", "2 0.99080 3.76 0.49 11.3 6 White Wine \n", "3 0.99020 3.16 0.31 12.5 6 White Wine \n", "4 0.99050 2.97 0.28 10.8 5 White Wine \n", "... ... ... ... ... ... ... \n", "32480 0.99238 3.75 0.76 11.3 6 White Wine \n", "32481 0.99660 3.07 0.46 9.4 5 White Wine \n", "32482 0.99374 3.80 0.66 10.8 6 White Wine \n", "32483 0.99012 3.36 0.41 13.5 7 White Wine \n", "32484 1.00040 3.27 0.58 9.0 4 White Wine \n", "\n", "[32485 rows x 14 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "markdown", "id": "09008208-91d4-49e0-806d-71461d5c3fdc", "metadata": {}, "source": [ "Drop the first column (`Unnamed: 0`, apparently a leftover row index) and `Type`, keep the remaining 11 physicochemical measurements as features, and isolate `quality` as the label." 
] }, { "cell_type": "code", "execution_count": 3, "id": "1eb05609-e631-4a26-8ce6-d94dd6f25911", "metadata": {}, "outputs": [], "source": [ "# The first column (\"Unnamed: 0\") appears to be a row index left over from the CSV export;\n", "# \"Type\" is not used as a feature and \"quality\" is the label.\n", "x_data = data.drop(columns=[data.columns[0], \"quality\", \"Type\"])\n", "y_data = data[\"quality\"]" ] }, { "cell_type": "markdown", "id": "040a8912", "metadata": {}, "source": [ "Split the samples into 80/20 train/test subsets." ] }, { "cell_type": "code", "execution_count": 4, "id": "c8dea41a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using 25988 samples for training and 6497 samples for testing.\n", "Total of 32485 records, dataset size is 32485 rows.\n", "Training set has a shape of (25988, 11), labels have a shape of (25988,)\n" ] } ], "source": [ "# Positional (unshuffled) split. The rows appear to be pre-shuffled already (see the unordered\n", "# \"Unnamed: 0\" values above) -- confirm this before relying on it.\n", "TRAIN_FRACTION = 0.8\n", "train_size = int(len(x_data) * TRAIN_FRACTION)\n", "x_train = x_data.iloc[:train_size]\n", "y_train = y_data.iloc[:train_size]\n", "x_test = x_data.iloc[train_size:]\n", "y_test = y_data.iloc[train_size:]\n", "assert len(x_train) + len(x_test) == len(x_data)\n", "\n", "print(f\"Using {len(x_train)} samples for training and {len(x_test)} samples for testing.\\n\"\n", "      f\"Total of {len(x_train) + len(x_test)} records, dataset size is {len(x_data)} rows.\\n\"\n", "      f\"Training set has a shape of {x_train.shape}, labels have a shape of {y_train.shape}\")" ] }, { "cell_type": "markdown", "id": "c4f5ae2f", "metadata": {}, "source": [ "Create a sequential model definition: two ReLU hidden layers (11 and 44 units) followed by a 10-unit softmax output, one probability per quality class 0-9." ] }, { "cell_type": "code", "execution_count": 5, "id": "ecd1f4fd", "metadata": {}, "outputs": [], "source": [ "import keras\n", "from keras import layers\n", "\n", "# Seed Python, NumPy and the backend so this first run is reproducible as well.\n", "keras.utils.set_random_seed(42)\n", "\n", "# 10 softmax units: with sparse categorical cross-entropy the labels must be integers in [0, 9].\n", "classifier_init = keras.Sequential([\n", "    layers.Dense(11, activation=\"relu\"),\n", "    layers.Dense(44, activation=\"relu\"),\n", "    layers.Dense(10, activation=\"softmax\")\n", "])" ] }, { "cell_type": "markdown", "id": "693762d0", "metadata": {}, "source": [ "Compile the model with adam optimiser and sparse categorical cross-entropy loss function. Track accuracy." 
] }, { "cell_type": "code", "execution_count": 6, "id": "775ce661", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-04-16 20:56:16.639749: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max\n", "2025-04-16 20:56:16.639783: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 64.00 GB\n", "2025-04-16 20:56:16.639792: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 24.00 GB\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1744826176.639810 59696050 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n", "I0000 00:00:1744826176.639835 59696050 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: )\n" ] } ], "source": [ "classifier_init.compile(\n", "    optimizer=\"adam\",  # Adam with its default learning rate\n", "    loss=\"sparse_categorical_crossentropy\",  # integer class labels, not one-hot vectors\n", "    metrics=[\"accuracy\"],\n", ")" ] }, { "cell_type": "markdown", "id": "de4aee83", "metadata": {}, "source": [ "Fit the model across 20 epochs with batches of 1000 samples, using a further 80/20 split for training and validation subsets." 
] }, { "cell_type": "code", "execution_count": 7, "id": "94db127a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-04-16 20:56:17.150147: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 44ms/step - accuracy: 0.0138 - loss: 32.4910 - val_accuracy: 0.3705 - val_loss: 8.4827\n", "Epoch 2/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step - accuracy: 0.3566 - loss: 7.4516 - val_accuracy: 0.3228 - val_loss: 4.4466\n", "Epoch 3/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 15ms/step - accuracy: 0.2947 - loss: 4.3572 - val_accuracy: 0.3032 - val_loss: 3.1995\n", "Epoch 4/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.3092 - loss: 2.8685 - val_accuracy: 0.3322 - val_loss: 1.9692\n", "Epoch 5/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.3381 - loss: 1.8321 - val_accuracy: 0.3405 - val_loss: 1.6127\n", "Epoch 6/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.3474 - loss: 1.5980 - val_accuracy: 0.3621 - val_loss: 1.5295\n", "Epoch 7/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.3604 - loss: 1.5212 - val_accuracy: 0.3588 - val_loss: 1.4604\n", "Epoch 8/20\n", "\u001b[1m21/21\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.3757 - loss: 1.4557 - val_accuracy: 0.3809 - val_loss: 1.4095\n", "Epoch 9/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.3969 - loss: 1.3931 - val_accuracy: 0.3948 - val_loss: 1.3708\n", "Epoch 10/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4164 - loss: 1.3594 - val_accuracy: 0.4071 - val_loss: 1.3422\n", "Epoch 11/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4286 - loss: 1.3367 - val_accuracy: 0.4204 - val_loss: 1.3164\n", "Epoch 12/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4420 - loss: 1.2971 - val_accuracy: 0.4240 - val_loss: 1.2971\n", "Epoch 13/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4422 - loss: 1.2808 - val_accuracy: 0.4323 - val_loss: 1.2831\n", "Epoch 14/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4440 - loss: 1.2797 - val_accuracy: 0.4375 - val_loss: 1.2734\n", "Epoch 15/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4593 - loss: 1.2620 - val_accuracy: 0.4396 - val_loss: 1.2652\n", "Epoch 16/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4515 - loss: 1.2635 - val_accuracy: 0.4357 - val_loss: 1.2605\n", "Epoch 17/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - 
accuracy: 0.4544 - loss: 1.2541 - val_accuracy: 0.4356 - val_loss: 1.2582\n", "Epoch 18/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4617 - loss: 1.2464 - val_accuracy: 0.4365 - val_loss: 1.2535\n", "Epoch 19/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4540 - loss: 1.2518 - val_accuracy: 0.4454 - val_loss: 1.2491\n", "Epoch 20/20\n", "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4577 - loss: 1.2375 - val_accuracy: 0.4361 - val_loss: 1.2477\n" ] } ], "source": [ "epochs_init = 20\n", "# validation_split=0.2 holds out the last 20% of x_train/y_train (selected before shuffling).\n", "history_init = classifier_init.fit(\n", "    x_train,\n", "    y_train,\n", "    batch_size=1000,\n", "    epochs=epochs_init,\n", "    validation_split=0.2,\n", ")" ] }, { "cell_type": "markdown", "id": "16dca591", "metadata": {}, "source": [ "The above doesn't seem to be achieving a particularly good accuracy. 
Let's tweak the model a bit and retrain:\n", "\n", "* add an explicit input specification\n", "* add more capacity to the deep layers (two layers of 128 units)\n", "* seed the random number generators so the run is reproducible\n", "* lower the Adam learning rate to 0.0001\n", "* train for more epochs (30 instead of 20)\n", "* decrease the batch size (500 instead of 1000)\n", "* hold out 30% of the training data for validation (instead of 20%)" ] }, { "cell_type": "code", "execution_count": 8, "id": "af4cf9c7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 24ms/step - accuracy: 0.3593 - loss: 17.3410 - val_accuracy: 0.3596 - val_loss: 3.3322\n", "Epoch 2/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.3873 - loss: 2.9919 - val_accuracy: 0.3977 - val_loss: 1.9189\n", "Epoch 3/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4113 - loss: 1.7894 - val_accuracy: 0.4211 - val_loss: 1.3407\n", "Epoch 4/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4389 - loss: 1.3377 - val_accuracy: 0.4422 - val_loss: 1.2794\n", "Epoch 5/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4576 - loss: 1.2830 - val_accuracy: 0.4526 - val_loss: 1.2479\n", "Epoch 6/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4629 - loss: 1.2533 - val_accuracy: 0.4617 - val_loss: 1.2289\n", "Epoch 7/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4738 - loss: 1.2359 - val_accuracy: 0.4615 - val_loss: 1.2208\n", "Epoch 8/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4770 - loss: 1.2270 - val_accuracy: 0.4645 - val_loss: 1.2155\n", "Epoch 9/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4809 - loss: 1.2223 - val_accuracy: 0.4657 - val_loss: 1.2157\n", "Epoch 10/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4810 - loss: 1.2186 - val_accuracy: 0.4701 - val_loss: 1.2144\n", "Epoch 11/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4813 - loss: 1.2167 - val_accuracy: 0.4733 - val_loss: 1.2142\n", "Epoch 12/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4810 - loss: 1.2151 - val_accuracy: 0.4742 - val_loss: 1.2139\n", "Epoch 13/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4813 - loss: 1.2138 - val_accuracy: 0.4724 - val_loss: 1.2130\n", "Epoch 14/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4817 - loss: 1.2128 - val_accuracy: 0.4743 - val_loss: 1.2123\n", "Epoch 15/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4822 - loss: 1.2120 - val_accuracy: 0.4716 - val_loss: 1.2119\n", "Epoch 16/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4822 - loss: 1.2115 - val_accuracy: 0.4724 - val_loss: 1.2111\n", "Epoch 17/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4823 - loss: 1.2104 - val_accuracy: 0.4724 - 
val_loss: 1.2103\n", "Epoch 18/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4825 - loss: 1.2101 - val_accuracy: 0.4735 - val_loss: 1.2098\n", "Epoch 19/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4835 - loss: 1.2096 - val_accuracy: 0.4747 - val_loss: 1.2101\n", "Epoch 20/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4828 - loss: 1.2089 - val_accuracy: 0.4769 - val_loss: 1.2081\n", "Epoch 21/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4845 - loss: 1.2080 - val_accuracy: 0.4780 - val_loss: 1.2078\n", "Epoch 22/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4839 - loss: 1.2070 - val_accuracy: 0.4781 - val_loss: 1.2064\n", "Epoch 23/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4846 - loss: 1.2058 - val_accuracy: 0.4789 - val_loss: 1.2049\n", "Epoch 24/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4834 - loss: 1.2048 - val_accuracy: 0.4802 - val_loss: 1.2044\n", "Epoch 25/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - accuracy: 0.4856 - loss: 1.2040 - val_accuracy: 0.4825 - val_loss: 1.2030\n", "Epoch 26/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4854 - loss: 1.2031 - val_accuracy: 0.4825 - val_loss: 1.2027\n", "Epoch 27/30\n", "\u001b[1m37/37\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4862 - loss: 1.2027 - val_accuracy: 0.4822 - val_loss: 1.2021\n", "Epoch 28/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4865 - loss: 1.2016 - val_accuracy: 0.4849 - val_loss: 1.2008\n", "Epoch 29/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.4871 - loss: 1.2008 - val_accuracy: 0.4865 - val_loss: 1.1995\n", "Epoch 30/30\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - accuracy: 0.4867 - loss: 1.1996 - val_accuracy: 0.4848 - val_loss: 1.1989\n" ] } ], "source": [ "# keras.utils.set_random_seed seeds Python's random, NumPy and the backend in one call,\n", "# so weight initialisation and per-epoch batch shuffling are reproducible.\n", "keras.utils.set_random_seed(42)\n", "\n", "classifier_new = keras.Sequential([\n", "    layers.Input(shape=(x_train.shape[1],)),  # one input per feature column\n", "    layers.Dense(128, activation=\"relu\"),\n", "    layers.Dense(128, activation=\"relu\"),\n", "    layers.Dense(10, activation=\"softmax\")  # labels must be integers in [0, 9]\n", "])\n", "\n", "classifier_new.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001),\n", "                       loss=\"sparse_categorical_crossentropy\",\n", "                       metrics=[\"accuracy\"])\n", "\n", "epochs_new = 30\n", "history_new = classifier_new.fit(x_train, y_train,\n", "                                 epochs=epochs_new, batch_size=500,\n", "                                 validation_split=0.3)" ] }, { "cell_type": "markdown", "id": "e0519464", "metadata": {}, "source": [ "This looks better, but how can you be sure? Visualise it!" ] }, { "cell_type": "code", "execution_count": 9, "id": "c236c890", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "\n", "def plot_history(history, title, legend_loc=(0.5, 0.5)):\n", "    \"\"\"Plot per-epoch training/validation loss and accuracy from a Keras History.\n", "\n", "    Loss is drawn against the left y-axis and accuracy against a twin right y-axis;\n", "    all four curves share one figure-level legend positioned with legend_loc.\n", "    The epoch count is read from the history itself, so it always matches the run.\n", "    Returns None (the figure is shown, not returned, to avoid a duplicate display).\n", "    \"\"\"\n", "    metrics = history.history\n", "    epochs = range(1, len(metrics[\"loss\"]) + 1)  # Keras numbers epochs from 1\n", "\n", "    fig, loss_ax = plt.subplots()\n", "    loss_ax.plot(epochs, metrics[\"val_loss\"], \"b-\", label=\"Validation Loss\")\n", "    loss_ax.plot(epochs, metrics[\"loss\"], \"r-\", label=\"Training Loss\")\n", "    loss_ax.set_ylabel(\"Loss\")\n", "    # Label the x-axis on the primary axes: twinx() hides the twin's x-axis, so a\n", "    # plt.xlabel() issued after twinx() targets an invisible axis and never shows.\n", "    loss_ax.set_xlabel(\"Epochs\")\n", "    loss_ax.set_xticks(range(0, len(epochs) + 1, 5))\n", "    loss_ax.set_title(title)\n", "\n", "    acc_ax = loss_ax.twinx()\n", "    acc_ax.plot(epochs, metrics[\"val_accuracy\"], \"g\", label=\"Validation Accuracy\")\n", "    acc_ax.plot(epochs, metrics[\"accuracy\"], \"y\", label=\"Training Accuracy\")\n", "    acc_ax.set_ylabel(\"Accuracy\")\n", "\n", "    # Merge both axes' entries into a single legend.\n", "    loss_handles, loss_labels = loss_ax.get_legend_handles_labels()\n", "    acc_handles, acc_labels = acc_ax.get_legend_handles_labels()\n", "    fig.legend(loss_handles + acc_handles, loss_labels + acc_labels, loc=legend_loc)\n", "\n", "    plt.show()\n", "\n", "\n", "plot_history(history_new, \"Tuned model: 2 x 128 ReLU, Adam lr=0.0001, batch 500\", legend_loc=(0.5, 0.5))\n" ] }, { "cell_type": "markdown", "id": "d2a84be1", "metadata": {}, "source": [ "Let's compare that with the previous run." ] }, { "cell_type": "code", "execution_count": 10, "id": "d5731bbc", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_history(history_init, \"Initial model: 11 + 44 ReLU, default Adam, batch 1000\", legend_loc=(0.1, 0.7))\n" ] } ], "metadata": { "kernelspec": { "display_name": "tf-2", "language": "python", "name": "tf-2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }