# Training notebook

## Please install the requirements
(if you don't have them already)

In [2]:
%pip install torchvision
%pip install torch torchvision
%pip install --upgrade ultralytics
%pip install pillow_heif
%pip install plotly
%pip install Pillow
%pip install tqdm
%pip install ruamel.yaml

Defaulting to user installation because normal site-packages is not writeable
Collecting torchvision
  Downloading torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl (7.0 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m0m eta [36m0:00:01[0m[36m0:00:01[0m
Collecting torch==2.3.0
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:07[0m
[?25hCollecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.6/121.6 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hCollecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2

In [1]:
import os  # for file operations
import json  # for loading the annotations file
from PIL import Image, ImageDraw, ImageOps, ImageEnhance  # for processing the image data
import numpy as np
from random import shuffle
import nvidia.cudnn
import torch
from pillow_heif import register_heif_opener

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

## Constants

In [2]:
DATA_PATH = "./data/"  # must contain multiple subdirectories - one for each class
ANNOTATIONS_PATH = "./via-annotations.json"  # relative to the notebook
YOLO_CONFIG_PATH = "yolo.yaml"

BATCH_SIZE = 8

CLASS_NAMES = sorted(list(set([os.path.basename(f).split(",")[0] for f in os.scandir(DATA_PATH) if f.is_dir()])))  # automatically generated
ALL_CLASS_NAMES = sorted(list(set([os.path.basename(f) for f in os.scandir(DATA_PATH) if f.is_dir()])))
CLASS_COUNT = len(CLASS_NAMES)

## Data size chart (treemap)

In [6]:
import os
import collections
import plotly.express as px
import plotly.io as pio
import pandas as pd

pio.renderers.default = "iframe"

subcategories = collections.defaultdict(list)

for klass in ALL_CLASS_NAMES:
    main, _, sub = klass.partition(", ")
    file_count = len(os.listdir(os.path.join(DATA_PATH, klass)))  # count files in each subcategory directory
    subcategories[main].append((sub, file_count))

categories = []
subcategories_list = []
values = []

for main, subs in subcategories.items():
    for sub, count in subs:
        categories.append(main)
        subcategories_list.append(sub)
        values.append(count)

df = pd.DataFrame({
    "Category": categories,
    "Subcategory": subcategories_list,
    "Number of samples": values
})

fig = px.treemap(
    df,
    path=[px.Constant("data"), "Category", "Subcategory"],
    values="Number of samples",
    title="Treemap of classes"
)

# The palette (from Google)
material_colours_hex = [
    "#009688",  # Teal
    "#4CAF50",  # Green
    "#8BC34A",  # Light Green
    "#CDDC39",  # Lime
    "#FFEB3B",  # Yellow
    "#FFC107",  # Amber
    "#FF9800",  # Orange
    "#FF5722",  # Deep Orange
    "#F44336",  # Red
    "#E91E63",  # Pink
    "#9C27B0",  # Purple
    "#673AB7",  # Deep Purple
    "#3F51B5",  # Indigo
    "#2196F3",  # Blue
    "#03A9F4",  # Light Blue
    "#00BCD4",  # Cyan
    "#795548",  # Brown
    "#9E9E9E",  # Grey
    "#607D8B",  # Blue Grey
]


fig.update_layout(
        width=2400, height=1200,
        treemapcolorway=material_colours_hex
)

fig.show()


## Data augmentation and conversion to YOLO format
We convert the data from the VGG Image Annotator format to the format used in the YOLO pipeline.

Data augmentation is also performed to increase the diversity of the dataset.
* Random rotation and flipping
* Random brightness
* Random contrast
* Random sharpness
* Random colour
* Random RGB deviation
* Salt and pepper noise
* JPEG compression

In [None]:
import json
import os
import shutil
import random
import math
from PIL import Image, UnidentifiedImageError, ImageEnhance, ImageOps
from tqdm import tqdm

def rotate_point(point, origin, angle):
    # Fast calculation for 90 degree rotations
    match angle % 360:
        case 0:
            return point
        case 90:
            return origin[1] + origin[0] - point[1], origin[1] - origin[0] + point[0]
        case 180:
            return 2 * origin[0] - point[0], 2 * origin[1] - point[1]
        case 270:
            return origin[1] - origin[0] + point[1], origin[1] + origin[0] - point[0]
    
    # Otherwise do math
    angle_rad = math.radians(angle)
    ox, oy = origin
    px, py = point
    
    qx = ox + math.cos(angle_rad) * (px - ox) - math.sin(angle_rad) * (py - oy)
    qy = oy + math.sin(angle_rad) * (px - ox) + math.cos(angle_rad) * (py - oy)
    
    return qx, qy

def save_photos(via_annotations_path, images_directory, output_directory, num_augmented_copies=3, save_quality=80):
    annotated_images = 0
    
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
    os.makedirs(output_directory)

    with open(via_annotations_path, 'r') as f:
        data = json.load(f)
        
    rotation_angles = [0, 90, 180, 270]

    class_mapping = {}

    for image_id, image_data in tqdm(data["_via_img_metadata"].items()):
        filename = image_data["filename"]
        image_path = os.path.normpath(os.path.join(images_directory, filename))
        try:
            image = Image.open(image_path)
        except (UnidentifiedImageError, IOError):
            continue        # skip non-image or missing files

        image.verify()      # verify that it is, indeed, an image
        image.close()

        image = Image.open(image_path)
        w, h = image.size
        cx, cy = w / 2, h / 2

        class_name = os.path.basename(os.path.dirname(image_path)).split(",")[0]
        if class_name not in class_mapping:
            # Add new classes when encountered
            class_mapping[class_name] = len(class_mapping)
            
        # Skip images with no annotations
        if "regions" not in image_data or len(image_data['regions']) == 0:
            continue
        
        annotated_images += 1

        # Generate augmented copies, but always keep the original image as well
        for i in range(num_augmented_copies + 1):
            if not i:
                output_image_path = os.path.join(output_directory, os.path.basename(image_path))
            else:
                output_image_path = os.path.join(output_directory, os.path.splitext(os.path.basename(image_path))[0] + "_" + str(i) + os.path.splitext(image_path)[1])    # append suffix to avoid clashes
            transform_info_output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(output_image_path))[0] + ".info.txt")
            if os.path.exists(transform_info_output_path):
                # This image was already transformed from another class, so we're adding the annotations from this one
                # as well
                with open(transform_info_output_path, "r") as info_file:
                    angle, flip = map(int, info_file.read().split())
            else:
                # Resize the image so the following operations are faster and the data size is reduced
                x_multiplier = 640/w
                y_multiplier = 640/h
                # image.thumbnail((640, 640), Image.HAMMING)
                image = image.resize((640, 640), Image.HAMMING)
                if not i:
                    # Keep the original image
                    angle = 0
                    flip = False
                    image.save(output_image_path)
                else:
                    angle = random.choice(rotation_angles)
                
                    # Random rotation
                    modified_image = image.rotate(angle, resample=Image.BICUBIC, expand=True)
                    # Random brightness
                    enhancer = ImageEnhance.Brightness(modified_image)
                    modified_image = enhancer.enhance(random.uniform(0.75, 1.25))
                    # Random contrast
                    enhancer = ImageEnhance.Contrast(modified_image)
                    modified_image = enhancer.enhance(random.uniform(0.625, 1.375))
                    # Random sharpness
                    enhancer = ImageEnhance.Sharpness(modified_image)
                    modified_image = enhancer.enhance(random.uniform(0.5, 1.5))
                    # Random colour
                    enhancer = ImageEnhance.Color(modified_image)
                    modified_image = enhancer.enhance(random.uniform(0.625, 1.375))
                    # Random horizontal flip
                    if random.random() < 0.5:
                        flip = True
                        modified_image = ImageOps.mirror(modified_image)
                    else:
                        flip = False
                    # Random RGB deviation
                    r_multiplier = random.uniform(0.875, 1.125)
                    g_multiplier = random.uniform(0.875, 1.125)
                    b_multiplier = random.uniform(0.875, 1.125)
                    r, g, b = modified_image.split()
                    r = r.point(lambda p: min(255, max(0, int(p * r_multiplier))))
                    g = g.point(lambda p: min(255, max(0, int(p * g_multiplier))))
                    b = b.point(lambda p: min(255, max(0, int(p * b_multiplier))))
                    modified_image = Image.merge("RGB", (r, g, b))
                    # Noise
                    try:
                        noise = Image.new("RGB", modified_image.size, (0, 0, 0))
                        for _ in range(int(image.width * image.height * 0.01)):
                            x, y = random.randint(0, image.width - 1), random.randint(0, image.height - 1)
                            # Salt or pepper?
                            if random.random() > 0.25:
                                noise.putpixel((x, y), (255, 255, 255))
                            else:
                                noise.putpixel((x, y), (0, 0, 0))
                        modified_image = Image.blend(modified_image, noise, 0.125)
                    except IndexError:
                        # Image is buggy
                        pass
                    modified_image.save(output_image_path, quality=save_quality)   # low quality to save space and diversify the dataset

            annotations_output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(output_image_path))[0] + ".txt")
            with open(annotations_output_path, "a") as file:
                if "regions" not in image_data or len(image_data["regions"]) == 0:
                    continue  # skip unannotated images
                for region in image_data['regions']:
                    shape_attr = region['shape_attributes']
                    points = []
                    if shape_attr['name'] == 'rect':
                        x, y, width, height = shape_attr['x'], shape_attr['y'], shape_attr['width'], shape_attr['height']
                        points = [(x, y), (x + width, y), (x + width, y + height), (x, y + height)]
                    elif shape_attr['name'] == 'polygon':
                        points = list(zip(shape_attr['all_points_x'], shape_attr['all_points_y']))
                    elif shape_attr['name'] == 'circle':
                        cx, cy, r = shape_attr['cx'], shape_attr['cy'], shape_attr['r']
                        points = [(cx - r, cy - r), (cx + r, cy - r), (cx + r, cy + r), (cx - r, cy + r)]
                    elif shape_attr['name'] == 'ellipse':
                        cx, cy, rx, ry = shape_attr['cx'], shape_attr['cy'], shape_attr['rx'], shape_attr['ry']
                        angle += shape_attr.get("theta", 0)  # Add the rotation angle
                        points = [(cx - rx, cy - ry), (cx + rx, cy - ry), (cx + rx, cy + ry), (cx - rx, cy + ry)]

                    # Rotate points
                    if flip:
                        points = [(w - pt[0], pt[1]) for pt in points]
                    rotated_points = [rotate_point(pt, (cx, cy), angle) for pt in points]
                    min_x = max(0, min(pt[0] for pt in rotated_points))
                    max_x = min(w, max(pt[0] for pt in rotated_points))
                    min_y = max(0, min(pt[1] for pt in rotated_points))
                    max_y = min(h, max(pt[1] for pt in rotated_points))

                    x_center = ((min_x + max_x) / 2) / w
                    y_center = ((min_y + max_y) / 2) / h
                    box_width = (max_x - min_x) / w
                    box_height = (max_y - min_y) / h
                    file.write(f"{class_mapping[class_name]} {x_center} {y_center} {box_width} {box_height}\n")
            with open(transform_info_output_path, "w") as info_file:
                    # Copies of the image will transform points in the same way
                    info_file.write(f"{angle} {int(flip)}")
    
    return class_mapping, annotated_images

OUTPUT_PATH = "yolo_data/"
class_mapping, annotated_images = save_photos(ANNOTATIONS_PATH, DATA_PATH, OUTPUT_PATH)

We set the class names in the YOLO configuration file to the ones discovered while loading the data.

In [None]:
print(f"{annotated_images} images processed.")

In [None]:
from ruamel.yaml import YAML
print(len(class_mapping), class_mapping)

# Set the class names
yaml = YAML()
with open(YOLO_CONFIG_PATH, "r") as file:
    data = yaml.load(file)
    data["nc"] = len(class_mapping)
    data["names"] = class_mapping

## Training

In [None]:
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

results = model.train(data=YOLO_CONFIG_PATH, epochs=72, verbose=True, imgsz=640)

In [None]:
torch.save(model.state_dict(), "model.pt")

## Inferencing
Retrieve the class mapping that was discovered earlier and used in training.

In [3]:
# Get the class mapping
from ruamel.yaml import YAML

yaml = YAML()

with open("yolo.yaml", "r") as file:
    data = yaml.load(file)
    class_mapping = data["names"]

### With live webcam feed

In [20]:
from ultralytics import YOLO
import os

model = YOLO("runs/detect/train26/weights/best.pt")

import cv2 as cv
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFont, ImageDraw

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Start the webcam
import numpy as np
import math
cap = cv.VideoCapture(0)
cap.set(3, 640)
cap.set(4, 640)


font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", 20)

while True:
    success, img = cap.read()
    results = model(img, stream=True, verbose=False)

    for r in results:
        boxes = r.boxes

        for box in boxes:
            confidence = round(float(box.conf[0] * 100))

            # Skip unsure detections
            if confidence < 0:
                continue
            
            # Get coordinates and draw the bounding box
            x1, y1, x2, y2 = box.xyxy[0]
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            
            cv.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2)
            cv.rectangle(img, (x1-1, y1-1), (x2+1, y2+1), (136, 150, 0), 2)
            
            cls = int(box.cls[0])

            # Object label
            img_pil = Image.fromarray(img)
            draw = ImageDraw.Draw(img_pil)
            draw.text((x1, y1),
                      f"[{confidence}%] {class_mapping[cls]}",
                      font=font,
                      fill=(255, 255, 255),
                      stroke_width=1,
                      stroke_fill=(0, 0, 0))
            img = np.array(img_pil)


    cv.imshow("Webcam capture demo", img)
    if cv.waitKey(1) == 27:    # escape
        break

# Stop the webcam
cap.release()
cv.destroyAllWindows()

KeyboardInterrupt: 

### From a photo file

In [19]:
from ultralytics import YOLO
import os

model = YOLO("runs/detect/train26/weights/last.pt")

import cv2 as cv
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFont, ImageDraw

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import numpy as np
import math

font = ImageFont.truetype("~/.local/share/fonts/Roboto-Regular.ttf", 40)

image_dir = "./test_data/"
output_dir = "./output/"

for image_path in os.scandir(image_dir):
    img = cv.imread(image_path.path)
    results = model(img)
    
    for r in results:
        boxes = r.boxes
        
        # Geometric mean of the confidence values
        min_confidence = np.mean([box.conf[0] for box in boxes]) * 87.5
    
        for box in boxes:
            confidence = round(float(box.conf[0] * 100))
    
            # Skip unsure detections
            if confidence < min_confidence:
                continue
            
            # Get coordinates and draw the bounding box
            x1, y1, x2, y2 = box.xyxy[0]
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            
            cv.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2)
            cv.rectangle(img, (x1-1, y1-1), (x2+1, y2+1), (136, 150, 0), 2)
            
            cls = int(box.cls[0])
    
            img_pil = Image.fromarray(img)
            draw = ImageDraw.Draw(img_pil)
            
            draw.text((x1, y1),
                      f"[{confidence}%] {class_mapping[cls]}",
                      font=font,
                      fill=(255, 255, 255),
                      stroke_width=1,
                      stroke_fill=(0, 0, 0))
            img = np.array(img_pil)
    
    cv.imwrite(os.path.join(output_dir, os.path.basename(image_path.path)), img)
    #cv.imshow("Photo", img)
    #cv.waitKey(27)         # escape
    #cv.destroyAllWindows()


0: 640x640 1 Glass, 21.5ms
Speed: 1.8ms preprocess, 21.5ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 Plastic, 40.2ms
Speed: 1.2ms preprocess, 40.2ms inference, 0.5ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 27.3ms
Speed: 2.9ms preprocess, 27.3ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 Paper, 65.2ms
Speed: 3.7ms preprocess, 65.2ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 640)




Mean of empty slice.


invalid value encountered in double_scalars



0: 640x640 (no detections), 37.9ms
Speed: 4.0ms preprocess, 37.9ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 31.0ms
Speed: 12.4ms preprocess, 31.0ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 21.6ms
Speed: 1.7ms preprocess, 21.6ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 640)

0: 480x640 1 Plastic, 22.5ms
Speed: 1.2ms preprocess, 22.5ms inference, 0.4ms postprocess per image at shape (1, 3, 480, 640)

0: 640x640 4 Plastics, 21.2ms
Speed: 1.7ms preprocess, 21.2ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 3 Metals, 23.6ms
Speed: 1.8ms preprocess, 23.6ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 (no detections), 21.8ms
Speed: 1.4ms preprocess, 21.8ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 Plastic, 1 Glass, 22.5ms
Speed: 2.1ms preprocess, 22.5ms inference, 0.5ms