Skip to content

Commit 58daca6

Browse files
ImageDataGenerator
Image Data Generator Made Using Kera's
1 parent 678621a commit 58daca6

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

ML Project/ImageDataGenerator

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Import necessary libraries
2+
import tensorflow as tf
3+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
4+
from tensorflow.keras.models import Sequential
5+
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
6+
7+
# Set up data directories
8+
train_dir = 'train'
9+
test_dir = 'test'
10+
11+
# Data Preprocessing
12+
train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
13+
test_datagen = ImageDataGenerator(rescale=1./255)
14+
15+
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(64, 64), batch_size=32, class_mode='binary')
16+
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(64, 64), batch_size=32, class_mode='binary')
17+
18+
# Build a Convolutional Neural Network (CNN) model
19+
model = Sequential()
20+
model.add(Conv2D(32, (3, 3), input_shape=(64, 64, 3), activation='relu'))
21+
model.add(MaxPooling2D(pool_size=(2, 2)))
22+
model.add(Conv2D(64, (3, 3), activation='relu'))
23+
model.add(MaxPooling2D(pool_size=(2, 2)))
24+
model.add(Flatten())
25+
model.add(Dense(units=128, activation='relu'))
26+
model.add(Dense(units=1, activation='sigmoid')
27+
28+
# Compile the model
29+
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
30+
31+
# Train the model
32+
model.fit(train_generator, steps_per_epoch=len(train_generator), epochs=10, validation_data=test_generator, validation_steps=len(test_generator))
33+
34+
# Evaluate the model
35+
test_loss, test_accuracy = model.evaluate(test_generator, steps=len(test_generator))
36+
print("Test accuracy: {:.2f}%".format(test_accuracy * 100))
37+
38+
# Save the model
39+
model.save('cat_dog_classifier.h5')

0 commit comments

Comments
 (0)