Data preprocessing

roundabout, created on Thursday, 7 September 2023, 05:10:06 (1694063406), received on Tuesday, 26 March 2024, 14:36:51 (1711463811)
Author identity: vlad <vlad.muntoiu@gmail.com>

f9fcd79d28ed1f0247132f50afe3b72ddde2363b

main.ipynb

@@ -17,7 +17,7 @@
    "source": [
     "import os  # for file operations\n",
     "import json  # for loading the annotations file\n",
-    "from PIL import Image, ImageDraw, ImageOps  # for processing the image data\n",
+    "from PIL import Image, ImageDraw, ImageOps, ImageEnhance  # for processing the image data\n",
     "import numpy as np\n",
     "from random import shuffle\n",
     "import nvidia.cudnn\n",
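The added ImageEnhance import is not exercised in the hunks shown here. As a minimal sketch of what the module offers, assuming it is meant for augmenting the images, the snippet below varies the brightness and contrast of a PIL image. The helper name `jitter` and the factor range are hypothetical and do not come from the notebook.

```python
import random
from PIL import Image, ImageEnhance


def jitter(img, rng, low=0.7, high=1.3):
    """Return a copy of img with randomly scaled brightness and contrast.

    Hypothetical helper: the name and the default range are illustrative.
    """
    # A factor of 1.0 leaves the image unchanged; 0.0 yields a black image
    # for Brightness and a uniformly grey image for Contrast.
    img = ImageEnhance.Brightness(img).enhance(rng.uniform(low, high))
    img = ImageEnhance.Contrast(img).enhance(rng.uniform(low, high))
    return img


rng = random.Random(0)  # seeded so the sketch is repeatable
original = Image.new("RGB", (8, 8), (120, 60, 30))
augmented = jitter(original, rng)
```

Each enhancer returns a new image, so the original stays available for producing further copies.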

@@ -98,6 +98,7 @@
     "\n",
     "num_images = len(via_annotations[\"_via_img_metadata\"])\n",
     "from tqdm.notebook import tqdm  # the progress bar\n",
+    "import random\n",
     "for i, image_data in tqdm(enumerate(via_annotations[\"_via_img_metadata\"].items()), total=num_images):\n",
     "    image_info = image_data[1]\n",
     "    image_id = image_data[0]\n",
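The loop above indexes each (key, value) pair as image_data[0] and image_data[1]. The same iteration can be written with tuple unpacking, sketched here on a tiny stand-in for the annotations file. The dictionary contents are made up for illustration; only the `_via_img_metadata` key comes from the notebook. tqdm is left out so the snippet runs without a notebook front end.

```python
# Hypothetical stand-in for the loaded annotations file.
via_annotations = {
    "_via_img_metadata": {
        "a.jpg1234": {"filename": "cats/a.jpg", "regions": []},
        "b.jpg5678": {"filename": "dogs/b.jpg", "regions": []},
    }
}

seen = []
# enumerate() yields (index, (key, value)); unpacking names both parts directly.
for i, (image_id, image_info) in enumerate(via_annotations["_via_img_metadata"].items()):
    seen.append((i, image_id, image_info["filename"]))
```

Wrapping the enumerate call in tqdm with total=num_images, as the notebook does, works the same way with the unpacked form.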

@@ -166,16 +167,13 @@
     "            pencil.rectangle(points, \"white\")\n",
     "    \n",
     "    # Resize the mask as well\n",
-    "    mask = mask.resize(IMAGE_RESOLUTION, resample=Image.NEAREST)\n",
-    "    \n",
-    "    # Convert them to arrays\n",
-    "    masked = np.asarray(Image.composite(img, black, mask)) / 255.0\n",
-    "    mask = np.asarray(mask) / 255.0\n",
-    "    img = np.asarray(img) / 255.0\n",
+    "    mask = mask.resize(IMAGE_RESOLUTION, resample=Image.NEAREST).convert(\"1\")\n",
+    "    masked = Image.composite(img, black, mask)\n",
     "\n",
     "    # Plot the three images\n",
     "    import matplotlib.pyplot as plt\n",
     "    f, axarr = plt.subplots(1, 3)\n",
+    "    plt.axis(\"on\")\n",
     "    axarr[0].imshow(img)\n",
     "    axarr[0].set_title(\"Original\")\n",
     "    \n",
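The mask handling in this hunk resizes the mask with nearest-neighbour sampling, converts it to 1-bit mode, and composites the photo over a black background. A self-contained sketch of that step follows. The image size, the rectangle coordinates, the value of IMAGE_RESOLUTION and the helper name `apply_mask` are assumptions chosen for illustration; the resize, convert and composite calls mirror the lines in the diff.

```python
from PIL import Image, ImageDraw

IMAGE_RESOLUTION = (32, 32)  # assumed value; the notebook defines its own


def apply_mask(img, mask, size=IMAGE_RESOLUTION):
    """Resize img and mask, then keep img only where the mask is white."""
    img = img.resize(size)
    # NEAREST avoids introducing grey values along the mask edges.
    mask = mask.resize(size, resample=Image.NEAREST).convert("1")
    black = Image.new(img.mode, size, 0)
    # composite() takes pixels from img where mask is set, else from black.
    return Image.composite(img, black, mask)


img = Image.new("RGB", (64, 64), (200, 30, 30))
mask = Image.new("L", (64, 64), 0)
ImageDraw.Draw(mask).rectangle([16, 16, 47, 47], "white")
masked = apply_mask(img, mask)
```

Keeping the result as a PIL image, rather than converting to an array at this point, leaves it usable by other PIL operations later in the pipeline.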

@@ -186,17 +184,34 @@
     "    axarr[2].set_title(\"Mask\")\n",
     "\n",
     "    plt.show()\n",
-    "\n",
-    "    # Convert the PIL image to a NumPy array and normalize\n",
-    "\n",
+    "    \n",
     "    # Get the class label from the directory structure\n",
     "    class_label = os.path.basename(os.path.dirname(filename))\n",
     "\n",
     "    # Convert class label to one-hot encoding\n",
-    "    class_index = CLASS_NAMES.index(class_label) # Assuming you have a list of class labels\n",
+    "    class_index = CLASS_NAMES.index(class_label)\n",
     "    class_one_hot = to_categorical(class_index, num_classes=CLASS_COUNT)\n",
+    "    \n",
+    "    # Preprocess\n",
+    "    n_copies = random.randint(1, 6)\n",
+    "    f, axarr = plt.subplots(1, n_copies + 1)\n",
+    "    plt.axis(\"off\")\n",
+    "    for j in range(0, n_copies):\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        random.seed(i+j // random.randint(1, 4) + random.randint(0, 16777216))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        \n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        preprocessed = masked.rotate(random.randint(0, 3) * 90 + random.randint(-22, 22))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        sharpness = ImageEnhance.Sharpness(img)\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        color = ImageEnhance.Color(img)\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        contrast = ImageEnhance.Contrast(img)\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        brightness = ImageEnhance.Brightness(img)\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        \n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        sharpness.enhance(random.uniform(0.5, 1.5))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        color.enhance(random.uniform(0.75, 1.25))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        contrast.enhance(random.uniform(0.625, 1.375))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            "        contrast.enhance(random.uniform(0.5, 1.5))\n",
                                        
                                        
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        
                                            
                                                "\n",
                                        
                                        
                                            
                                            
                                            
                                            
                                        
                                    
                                
                                
                                
                            
-    "    annotations.append((masked, class_one_hot))\n",
+    "        annotations.append((np.asarray(preprocessed) / 255.0, class_one_hot))\n",
+    "        axarr[j].imshow(preprocessed)\n",
     "\n",
     "# Shuffle the data\n",
     "shuffle(annotations)\n",

@@ -262,20 +277,6 @@
     "## training"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "4ab166f7-dd8f-4ae2-9062-ded152a31d1e",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# clear GPU memory\n",
-    "from tensorflow.python.framework import ops\n",
-    "\n",
-    "ops.reset_default_graph()\n",
-    "tf.keras.backend.clear_session()"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,

@@ -292,6 +293,7 @@
     "    epochs=EPOCHS,\n",
     "    verbose=0,\n",
     "    callbacks=[TqdmCallback()],\n",
+    "    validation_split=0.2\n",
     ")"
    ]
   },

@@ -356,7 +358,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "2d6b22f3-1a2b-4094-aecf-b6bd8f5fba02",
-   "metadata": {},
+   "metadata": {
+    "scrolled": true
+   },
    "outputs": [],
    "source": [
     "history = model.fit(\n",
                                            
                                            
                                            
                                            
                                        
                                    
                                
                                
                                
                            
                                
                                    
                                        

@@ -366,6 +370,7 @@
     "    epochs=EPOCHS,\n",
     "    verbose=0,\n",
     "    callbacks=[TqdmCallback()],\n",
+    "    validation_split=0.2\n",
     ")"
    ]
   },

@@ -396,7 +401,8 @@
    "source": [
     "from tensorflow.keras.models import load_model\n",
     "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
-    "import numpy as np"
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt"
    ]
   },
   {

@@ -406,7 +412,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = load_model(\"model.h5\")"
+    "model = load_model(\"model.h5\", compile=False)"
    ]
   },
   {

model.h5