Lecture03. Classification (Chapter 3)
FALL 2024
ML Algorithms
DA 515
DA 535
Deep Learning: more resources
Main Points of Chapter 3
Classification:
Binary (1/0, or True/False)
Multiclass (0, 1, …, 9)
Evaluation Metrics:
Cross-Validation
Accuracy, Precision, Recall, F1
ROC/AUC
MNIST: Modified National Institute of Standards and Technology
70,000 small images of handwritten digits
Each image is labeled: 0, 1, …, 9
Load in dataset
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)  # as_frame=False keeps NumPy arrays in newer scikit-learn
Read in from a dictionary: the 'data' key (pixels) and the 'target' key (labels)
X: 70,000 rows (one per image) × 784 columns (28 × 28 pixels)
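A minimal sketch of unpacking the result (the labels arrive as strings, so cast them to integers):

import numpy as np

X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8)   # labels come back as strings such as '5'
print(X.shape)           # (70000, 784)
print(y.shape)           # (70000,)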
One Image: digit 5
# from 1-d to 2-d
X[0].reshape(28, 28)
# a 28 × 28 array of pixel intensities: 0 for background, up to 255 where the pen drew
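To actually see the digit, display the reshaped array with Matplotlib (a minimal sketch):

import matplotlib.pyplot as plt

some_digit = X[0]                              # first training image
some_digit_image = some_digit.reshape(28, 28)  # from 1-d (784,) to 2-d (28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()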
We focus on evaluating ML models
For now, we skip the following steps:
Missing value handling
Outlier detection
Data Scaling
Feature Selection
PCA
Imbalanced data
….
Split data sets: Training vs. Testing
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
1. Binary Classification
Simple Binary problems: one-versus-the-rest (OvR)
True (class 1) if it is number 5
False (class 0) if not
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
SGD: use one sample at a time to update.
Mini-Batch: use a small batch of samples to update.
Full (batch): use all samples to update.
SGD vs. Mini-Batch vs. Full
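The three variants differ only in how many samples feed each update. A schematic sketch (not from the textbook; a linear-regression gradient is used just to illustrate the batching):

import numpy as np

def gradient_step(X, y, w, lr):
    # gradient of mean squared error for a linear model: (2/m) * X^T (Xw - y)
    grad = 2 / len(X) * X.T @ (X @ w - y)
    return w - lr * grad

def run_epoch(X, y, w, lr, batch_size):
    # batch_size = 1 -> SGD; small batch_size -> Mini-Batch; batch_size = len(X) -> Full
    idx = np.random.permutation(len(X))
    for start in range(0, len(X), batch_size):
        batch = idx[start:start + batch_size]
        w = gradient_step(X[batch], y[batch], w, lr)
    return w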
1.1 SGD: stochastic gradient descent
Classifier: Stochastic Gradient Descent (SGD)
Built into Scikit-Learn
from sklearn.linear_model import SGDClassifier
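A minimal training sketch, using the X_train and y_train_5 defined on the earlier slides:

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)    # train the binary 5-detector
sgd_clf.predict([some_digit])      # array([ True]) if it believes the image is a 5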
Performance Measures
Accuracy: correct/all
Cross Validation
Confusion Matrix
Precision: accuracy of the positive predictions
Recall (sensitivity): true positive rate
F1: harmonic mean of precision and recall
ROC/AUC
Test: Accuracy = correct/All
print(sgd_clf.predict(X_test)[0:50]) # predicted for the first 50
print(y_test_5[0:50]) # real label for the first 50
Predicted
[False False False False False False False False False False False False
False False False True False False False False False False False True
False False False False False False False False False False False False
False False False False False False False False False True False False
False False]
Real
[False False False False False False False False True False False False
False False False True False False False False False False False True
False False False False False False False False False False False False
False False False False False False False False False True False False
False False]
Imbalanced Data
Why is the accuracy so high (>95%)?
Balanced data: T:F = 50:50, so 50% accuracy is the baseline.
Imbalanced data: only about 10% of MNIST images are 5s, so always guessing "not 5" already scores about 90% (see the sketch below).
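A quick check with scikit-learn's DummyClassifier, which always predicts the most frequent class ("not 5"):

from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score

dummy_clf = DummyClassifier(strategy="most_frequent")
cross_val_score(dummy_clf, X_train, y_train_5, cv=3, scoring="accuracy")
# roughly 0.90 per fold: high accuracy with zero skill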
Try models: Cross-Validation
3 folds applied to the training data
Sampling: stratified (each fold keeps the same class ratio)
skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
# shuffle=True is required when setting random_state in current scikit-learn
OUTPUT:
# 0.95035, 0.96035, 0.9604
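The same three scores come from the cross_val_score one-liner (a sketch, assuming the sgd_clf trained above):

from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
# array([0.95035, 0.96035, 0.9604])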
Confusion Matrix
Accuracy = (TP + TN) / All
Also generalizes to multiple classes
Confusion Matrix Code
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# out-of-fold predictions for every training sample
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
confusion_matrix(y_train_5, y_train_pred)
# output (rows = actual, columns = predicted):
array([[53892,   687],
       [ 1891,  3530]], dtype=int64)
Other metrics
Precision = TP / (TP + FP): accuracy of the positive predictions
Recall (sensitivity) = TP / (TP + FN): true positive rate
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
# == 4096 / (4096 + 1522) = 0.7290850836596654
When it claims an image represents a 5, it is correct only 72.9% of the time.
recall_score(y_train_5, y_train_pred)
# == 4096 / (4096 + 1325) = 0.7555801512636044
It only detects 75.6% of the real 5s.
Metric F1
F1 = 2 × precision × recall / (precision + recall): the harmonic mean of precision and recall
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
# 0.7325171197343846
Classification Report
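scikit-learn prints precision, recall, and F1 for both classes in one call; a sketch, assuming y_train_5 and y_train_pred from the previous slides:

from sklearn.metrics import classification_report

print(classification_report(y_train_5, y_train_pred))
# one row per class with precision, recall, f1-score, and support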
Precision vs. Recall
(both change as the decision threshold changes; see the sketch below)
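A sketch of computing and plotting both curves against the threshold, using decision scores from cross-validation:

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

# scores, not labels, are needed to sweep the threshold
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
                             method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

plt.plot(thresholds, precisions[:-1], label="Precision")
plt.plot(thresholds, recalls[:-1], label="Recall")
plt.xlabel("Decision threshold")
plt.legend()
plt.show()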
Decision Threshold (pick the best trade-off)
(Figure: digits ranked by classifier score, 6 of them true 5s; raising the threshold increases precision but lowers recall.)
Another Example: continuous distribution
ROC
The receiver operating characteristic (ROC) curve plots the true positive rate (recall) against the false positive rate across all thresholds.
ROC curve
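A sketch of plotting it, reusing the y_scores computed for the precision/recall curve:

from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

plt.plot(fpr, tpr, label="SGD")
plt.plot([0, 1], [0, 1], "k--", label="random classifier")
plt.xlabel("False positive rate")
plt.ylabel("True positive rate (recall)")
plt.legend()
plt.show()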
Area Under the Curve (AUC): 1.0 for a perfect classifier, 0.5 for a purely random one
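One call computes it (a sketch):

from sklearn.metrics import roc_auc_score

roc_auc_score(y_train_5, y_scores)   # about 0.96 for the SGD model (next slide)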
ROC for different Algorithms
SGD: AUC = 0.9604938554008616
Random Forest: AUC = 0.9983436731328145
Change threshold: prediction changes (5 or not 5)
y_scores = sgd_clf.decision_function([some_digit])
# array([2164.22030239])
Choose the precision/recall trade-off that fits your needs, and use ROC curves and ROC AUC scores to compare various models.
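Comparing the score to a hand-picked threshold turns it into a prediction (the threshold values here are illustrative):

threshold = 0
y_some_digit_pred = (y_scores > threshold)   # array([ True]): predicted as a 5

threshold = 8000
y_some_digit_pred = (y_scores > threshold)   # array([False]): no longer a 5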
2. Multiclass Classification (0, 1, …, 9)
Some algorithms (such as SGD classifiers, Random Forest classifiers, and naive Bayes classifiers) are capable of handling multiple classes natively.
Others (such as Logistic Regression or Support Vector Machine classifiers) are strictly binary and need a strategy such as one-versus-the-rest.
Algorithms used here:
(1) SGD
(2) Random Forest
(2.1) SGD
After fitting on all 10 classes (sgd_clf.fit(X_train, y_train)), the decision_function() method now returns one value per class.
sgd_clf.decision_function([some_digit])
# array([[-15955.22628, -38080.96296, -13326.66695,    573.52692,
#         -17680.68466,   2412.53175, -25526.86498, -12290.15705,
#          -7946.05205, -10631.35889]])
# the largest score (2412.53) is at index 5, so the prediction is class 5
Scaling the inputs improves accuracy:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
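A sketch of recovering the predicted label from those scores:

import numpy as np

scores = sgd_clf.decision_function([some_digit])
class_index = int(np.argmax(scores))   # 5: position of the largest score
sgd_clf.classes_[class_index]          # the matching class label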
2.2 RandomForestClassifier
Just switch the classifier from SGD to Random Forest:
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
forest_clf.fit(X_train, y_train)
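Unlike SGD, a random forest can report per-class probabilities directly (a sketch):

forest_clf.predict_proba([some_digit])
# one probability per class; the largest entry should sit at class 5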
If this were a real project
Error Analysis:
follow the ML project checklist (see Appendix B)
try out multiple models
fine-tune their hyperparameters using GridSearchCV
then analyze the errors, starting with the CONFUSION MATRIX
Visualization
import matplotlib.pyplot as plt
plt.matshow(conf_mx, cmap=plt.cm.gray)   # bright diagonal = correct predictions
plt.show()
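A sketch of building conf_mx for the full 10-class problem before plotting it (same cross_val_predict pattern as in the binary case):

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)   # 10 × 10: rows = actual, columns = predicted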
SGD model: examples of misclassified 3s and 5s
Many of the errors come down to bad handwriting.
3. (skip) Multilabel Classification
So far, each sample has been assigned to exactly one class.
Multilabel classification: the classifier outputs multiple binary tags for each instance.
Example: face recognition with three known faces: Alice, Bob, and Charlie.
For a picture of Alice and Charlie, the classifier outputs [1, 0, 1] (meaning "Alice yes, Bob no, Charlie yes"); see the sketch below.
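A minimal MNIST sketch: attach two binary tags to each digit and train a KNeighborsClassifier, which handles multilabel targets natively (the tag definitions here are illustrative):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

y_train_large = (y_train >= 7)       # tag 1: is the digit large?
y_train_odd = (y_train % 2 == 1)     # tag 2: is the digit odd?
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])        # e.g. array([[False,  True]]) for a 5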
4. (skip) Multioutput Classification
multioutput–multiclass classification (or simply multioutput
classification).
It is simply a generalization of multilabel classification
where each label can be multiclass (i.e., it can have more
than two possible values).
noise from images. It will take as input a noisy digit image,
and it will (hopefully) output a clean digit image,
represented as an array of pixel intensities, just like the
MNIST images.
Notice that the classifier’s output is multilabel (one label
per pixel) and each label can have multiple values (pixel
intensity ranges from 0 to 255). It is thus an example of a
multioutput classification system. 42
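A sketch of that denoising setup (the noise range is illustrative):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise     # noisy inputs
y_train_mod = X_train             # clean images as targets: one label (0-255) per pixel

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_train_mod[0]])   # a denoised 784-pixel image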
Summary
Classification
Binary (yes, no)
Multiclass (0, 1, …, 9)
Evaluation Metrics:
Confusion Matrix
Accuracy
ROC/AUC
Cross-Validation
Trade-off
Optional HW: Data augmentation
(not to be turned in)
END