roundabout,
created on Monday, 28 August 2023, 13:23:38 (1693229018),
received on Tuesday, 26 March 2024, 14:36:52 (1711463812)
Author identity: vlad <vlad.muntoiu@gmail.com>
b8037a7bfa24bb21d1b1a92e1dc5f60e6837b354
main.ipynb
@@ -0,0 +1,176 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "57b7c687",
"metadata": {},
"source": [
"# LiteWaste training notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c32d318b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"from tensorflow.keras.applications import VGG16\n",
"from tensorflow.keras.layers import Dense, GlobalAveragePooling2D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.optimizers import Adam"
]
},
{
"cell_type": "markdown",
"id": "be23d629",
"metadata": {},
"source": [
"## constants"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cecf74e",
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = \"./data/\"\n",
"BATCH_SIZE = 32\n",
"IMAGE_RESOLUTION = (224, 224)\n",
"EPOCHS = 40\n",
"CLASS_COUNT = len(os.listdir(dataset_dir))"
]
},
{
"cell_type": "markdown",
"id": "c868fd1f",
"metadata": {},
"source": [
"## data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be9c5e4a",
"metadata": {},
"outputs": [],
"source": [
"datagen = ImageDataGenerator(\n",
" rescale=1./255,\n",
" rotation_range=45,\n",
" width_shift_range=0.125,\n",
" height_shift_range=0.125,\n",
" shear_range=0.25,\n",
" zoom_range=0.2,\n",
" horizontal_flip=True,\n",
" fill_mode=\"nearest\",\n",
" brightness_range=[0.75, 1.25],\n",
" channel_shift_range=16,\n",
")\n",
"\n",
"augmented_images = []\n",
"for i in range(4):\n",
" augmented_images.extend(datagen.flow_from_directory(\n",
" DATA_PATH,\n",
" target_size=IMAGE_RESOLUTION,\n",
" batch_size=BATCH_SIZE,\n",
" class_mode=\"categorical\",\n",
" ))"
]
},
{
"cell_type": "markdown",
"id": "3c98bbd7",
"metadata": {},
"source": [
"## model definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b59988d4",
"metadata": {},
"outputs": [],
"source": [
"base_model = VGG19(weights=\"imagenet\", include_top=False)\n",
"x = base_model.output\n",
"x = GlobalAveragePooling2D()(x)\n",
"x = Dense(1024, activation=\"relu\")(x)\n",
"predictions = Dense(num_classes, activation=\"softmax\")(x)\n",
"model = Model(inputs=base_model.input, outputs=predictions)\n",
"\n",
"for layer in base_model.layers:\n",
" layer.trainable = False"
]
},
{
"cell_type": "markdown",
"id": "cb1380e7",
"metadata": {},
"source": [
"## compilation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b3acdd4",
"metadata": {},
"outputs": [],
"source": [
"model.compile(optimizer=Adam(learning_rate=0.0001),\n",
" loss=\"categorical_crossentropy\",\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"id": "8fc14089",
"metadata": {},
"source": [
"## training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8ef0a94",
"metadata": {},
"outputs": [],
"source": [
"history = model.fit(\n",
" train_generator,\n",
" steps_per_epoch=train_generator.samples // BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}