roundabout,
created on Thursday, 7 September 2023, 05:10:06 (1694063406),
received on Tuesday, 26 March 2024, 14:36:51 (1711463811)
Author identity: vlad <vlad.muntoiu@gmail.com>
f9fcd79d28ed1f0247132f50afe3b72ddde2363b
main.ipynb
@@ -17,7 +17,7 @@
"source": [ "import os # for file operations\n", "import json # for loading the annotations file\n", "from PIL import Image, ImageDraw, ImageOps # for processing the image data\n","from PIL import Image, ImageDraw, ImageOps, ImageEnhance # for processing the image data\n","import numpy as np\n", "from random import shuffle\n", "import nvidia.cudnn\n",
@@ -98,6 +98,7 @@
"\n", "num_images = len(via_annotations[\"_via_img_metadata\"])\n", "from tqdm.notebook import tqdm # the progress bar\n", "import random\n","for i, image_data in tqdm(enumerate(via_annotations[\"_via_img_metadata\"].items()), total=num_images):\n", " image_info = image_data[1]\n", " image_id = image_data[0]\n",
@@ -166,16 +167,13 @@
" pencil.rectangle(points, \"white\")\n", " \n", " # Resize the mask as well\n", " mask = mask.resize(IMAGE_RESOLUTION, resample=Image.NEAREST)\n"," \n"," # Convert them to arrays\n"," masked = np.asarray(Image.composite(img, black, mask)) / 255.0\n"," mask = np.asarray(mask) / 255.0\n"," img = np.asarray(img) / 255.0\n"," mask = mask.resize(IMAGE_RESOLUTION, resample=Image.NEAREST).convert(\"1\")\n", " masked = Image.composite(img, black, mask)\n","\n", " # Plot the three images\n", " import matplotlib.pyplot as plt\n", " f, axarr = plt.subplots(1, 3)\n", " plt.axis(\"on\")\n"," axarr[0].imshow(img)\n", " axarr[0].set_title(\"Original\")\n", " \n",
@@ -186,17 +184,34 @@
" axarr[2].set_title(\"Mask\")\n", "\n", " plt.show()\n", "\n"," # Convert the PIL image to a NumPy array and normalize\n","\n"," \n"," # Get the class label from the directory structure\n", " class_label = os.path.basename(os.path.dirname(filename))\n", "\n", " # Convert class label to one-hot encoding\n", " class_index = CLASS_NAMES.index(class_label) # Assuming you have a list of class labels\n"," class_index = CLASS_NAMES.index(class_label)\n"," class_one_hot = to_categorical(class_index, num_classes=CLASS_COUNT)\n", " \n", " # Preprocess\n", " n_copies = random.randint(1, 6)\n", " f, axarr = plt.subplots(1, n_copies + 1)\n", " plt.axis(\"off\")\n", " for j in range(0, n_copies):\n", " random.seed(i+j // random.randint(1, 4) + random.randint(0, 16777216))\n", " \n", " preprocessed = masked.rotate(random.randint(0, 3) * 90 + random.randint(-22, 22))\n", " sharpness = ImageEnhance.Sharpness(img)\n", " color = ImageEnhance.Color(img)\n", " contrast = ImageEnhance.Contrast(img)\n", " brightness = ImageEnhance.Brightness(img)\n", " \n", " sharpness.enhance(random.uniform(0.5, 1.5))\n", " color.enhance(random.uniform(0.75, 1.25))\n", " contrast.enhance(random.uniform(0.625, 1.375))\n", " contrast.enhance(random.uniform(0.5, 1.5))\n","\n", " annotations.append((masked, class_one_hot))\n"," annotations.append((np.asarray(preprocessed) / 255.0, class_one_hot))\n", " axarr[j].imshow(preprocessed)\n","\n", "# Shuffle the data\n", "shuffle(annotations)\n",
@@ -262,20 +277,6 @@
"## training" ] }, {"cell_type": "code","execution_count": null,"id": "4ab166f7-dd8f-4ae2-9062-ded152a31d1e","metadata": {},"outputs": [],"source": ["# clear GPU memory\n","from tensorflow.python.framework import ops\n","\n","ops.reset_default_graph()\n","tf.keras.backend.clear_session()"]},{ "cell_type": "code", "execution_count": null,
@@ -292,6 +293,7 @@
" epochs=EPOCHS,\n", " verbose=0,\n", " callbacks=[TqdmCallback()],\n", " validation_split=0.2\n",")" ] },
@@ -356,7 +358,9 @@
"cell_type": "code", "execution_count": null, "id": "2d6b22f3-1a2b-4094-aecf-b6bd8f5fba02", "metadata": {},"metadata": { "scrolled": true },"outputs": [], "source": [ "history = model.fit(\n",
@@ -366,6 +370,7 @@
" epochs=EPOCHS,\n", " verbose=0,\n", " callbacks=[TqdmCallback()],\n", " validation_split=0.2\n",")" ] },
@@ -396,7 +401,8 @@
"source": [ "from tensorflow.keras.models import load_model\n", "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n", "import numpy as np""import numpy as np\n", "import matplotlib.pyplot as plt"] }, {
@@ -406,7 +412,7 @@
"metadata": {}, "outputs": [], "source": [ "model = load_model(\"model.h5\")""model = load_model(\"model.h5\", compile=False)"] }, {