Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions support_vector_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
# Fitting the classifier into the Training set

from sklearn.svm import SVC
classifier = SVC(kernel = 'linear', random_state = 0)
kernel_type = 'rbf' # options: 'linear', 'rbf', 'poly', 'sigmoid'
classifier = SVC(kernel=kernel_type, random_state=0)

classifier.fit(X_Train, Y_Train)

# Predicting the test set results
Expand All @@ -37,7 +39,14 @@

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(Y_Test, Y_Pred)
print("Confusion Matrix:\n", cm)

# Add accuracy and classification report
from sklearn.metrics import accuracy_score, classification_report
print("Accuracy:", accuracy_score(Y_Test, Y_Pred))
print("\nClassification Report:\n", classification_report(Y_Test, Y_Pred))

# Step 7: Visualizing the Training Results
# Visualising the Training set results

from matplotlib.colors import ListedColormap
Expand All @@ -51,9 +60,11 @@
for i, j in enumerate(np.unique(Y_Set)):
plt.scatter(X_Set[Y_Set == j, 0], X_Set[Y_Set == j, 1],
c = ListedColormap(('red', 'green'))(i), label = j)
plt.title('Support Vector Machine (Training set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.title(f'Support Vector Machine ({kernel_type.capitalize()} Kernel) - Training set')

plt.xlabel('Age (years)')
plt.ylabel('Estimated Salary ($)')

plt.legend()
plt.show()

Expand Down