Random Forest Classifier using Scikit-learn
Last Updated: 30 May, 2025
Random Forest is a method that combines the predictions of multiple decision trees to produce a more accurate and stable result. It can be used for both classification and regression tasks.
In classification tasks, Random Forest predicts categorical outcomes from the input data: it runs the input through multiple decision trees and outputs the label that receives the most votes among the individual tree predictions.
Working of Random Forest Classifier
- Bootstrap Sampling: Random rows are picked (with replacement) to train each tree.
- Random Feature Selection: Each tree uses a random set of features (not all features).
- Build Decision Trees: Trees split the data using the best feature from their random set. Splitting continues until a stopping rule is met (like max depth).
- Make Predictions: Each tree gives its own prediction.
- Majority Voting: The final prediction is the label most trees agree on, as illustrated in the toy sketch below.
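To make these steps concrete, here is a small hand-rolled sketch (not how scikit-learn's RandomForestClassifier is implemented internally): it draws bootstrap samples, fits one DecisionTreeClassifier per sample on a random subset of features, and takes a majority vote over the tree predictions. The values of n_trees and n_features are arbitrary choices for the demo.
python
import numpy as np
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
rng = np.random.default_rng(42)

n_trees, n_features = 5, 2  # small demo values, not tuned
trees, feature_subsets = [], []

for _ in range(n_trees):
    rows = rng.integers(0, len(X), size=len(X))                    # bootstrap sampling: rows drawn with replacement
    cols = rng.choice(X.shape[1], size=n_features, replace=False)  # random feature selection for this tree
    tree = DecisionTreeClassifier(max_depth=3, random_state=0)     # stopping rule: max depth
    tree.fit(X[rows][:, cols], y[rows])
    trees.append(tree)
    feature_subsets.append(cols)

# Each tree predicts using its own feature subset; the majority vote is the final label
sample = X[0].reshape(1, -1)
votes = [t.predict(sample[:, cols])[0] for t, cols in zip(trees, feature_subsets)]
print("Predicted class:", Counter(votes).most_common(1)[0][0])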
Benefits of Random Forest Classification:
- Random Forest can handle large datasets and high-dimensional data.
- By combining predictions from many decision trees, it reduces the risk of overfitting compared to a single decision tree.
- It is robust to noisy data and works well with categorical data.
Implementing Random Forest Classification in Python
Before implementing a Random Forest Classifier in Python, let's first understand its parameters; a short example of setting them follows the list.
- n_estimators: Number of trees in the forest.
- max_depth: Maximum depth of each tree.
- max_features: Number of features considered for splitting at each node.
- criterion: Function used to measure split quality ('gini' or 'entropy').
- min_samples_split: Minimum samples required to split a node.
- min_samples_leaf: Minimum samples required to be at a leaf node.
- bootstrap: Whether to use bootstrap sampling when building trees (True or False).
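As a quick illustration of how these parameters are set, the sketch below constructs a classifier with each of them passed explicitly; the values shown are examples only, not recommended defaults.
python
from sklearn.ensemble import RandomForestClassifier

# Example values only; tune them for your own dataset
rf = RandomForestClassifier(
    n_estimators=100,      # number of trees in the forest
    max_depth=5,           # maximum depth of each tree
    max_features='sqrt',   # features considered at each split
    criterion='gini',      # split-quality measure ('gini' or 'entropy')
    min_samples_split=2,   # minimum samples required to split a node
    min_samples_leaf=1,    # minimum samples required at a leaf node
    bootstrap=True,        # use bootstrap sampling when building trees
    random_state=42
)
print(rf)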
Now that we know its parameters, we can start building the model in Python.
1. Import Required Libraries
We will import pandas, matplotlib, seaborn and scikit-learn to build the model.
python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
2. Import Dataset
For this we'll use the Iris Dataset, which is available within scikit-learn. This dataset contains information about three types of Iris flowers and their respective features (sepal length, sepal width, petal length and petal width).
python
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['target'] = iris.target
df
Output:
Iris Dataset
3. Data Preparation
Here we will separate the features (X) and the target variable (y).
python
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values
4. Splitting the Dataset
We'll split the dataset into training and testing sets so we can train the model on one part and evaluate it on another.
- X_train, y_train: 80% of the data used to train the model.
- X_test, y_test: 20% of the data used to test the model.
- test_size=0.2: means 20% of data goes to testing.
- random_state=42: ensures you get the same split every time.
python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
5. Feature Scaling
Feature scaling puts all the features on a similar scale, which is important for many machine learning models. Random Forest itself is not highly sensitive to feature scaling because its splits are based on feature thresholds rather than distances, but scaling is still good practice when the same pipeline is combined with scale-sensitive models.
python
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
6. Building Random Forest Classifier
We will create the Random Forest Classifier model, train it on the training data and make predictions on the test data.
- RandomForestClassifier(n_estimators=100, random_state=42) creates 100 trees (100 trees balance accuracy and training time).
- classifier.fit(X_train, y_train) trains on training data.
- classifier.predict(X_test) predicts on test data.
- random_state=42 ensures reproducible results.
python
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
7. Evaluation of the Model
We will evaluate the model using the accuracy score and confusion matrix.
python
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy * 100:.2f}%')
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues', cbar=False,
xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.title('Confusion Matrix Heatmap')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()
Output:
Accuracy: 100.00%
Confusion Matrix
8. Feature Importance
Random Forest Classifiers also provide insight into which features were the most important in making predictions. We can plot the feature importance.
python
feature_importances = classifier.feature_importances_
plt.barh(iris.feature_names, feature_importances)
plt.xlabel('Feature Importance')
plt.title('Feature Importance in Random Forest Classifier')
plt.show()
Output:
Feature Importance in Random Forest Classifier
From the graph we can see that petal width (cm) is the most important feature, followed closely by petal length (cm). The sepal width (cm) and sepal length (cm) have lower importance in determining the model's predictions. This indicates that the classifier relies more on the petal measurements to make predictions about the flower species.
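If you also want the ranking as plain text rather than a chart, the importances from the classifier trained above can be wrapped in a pandas Series and sorted; this is just a convenience step on top of the existing code.
python
importances = pd.Series(classifier.feature_importances_, index=iris.feature_names)
print(importances.sort_values(ascending=False))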
Random Forest can also be used for regression problems: see Random Forest Regression in Python.
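For completeness, here is a minimal sketch of the regression variant using RandomForestRegressor from the same sklearn.ensemble module; the synthetic dataset and parameter values are illustrative only.
python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Synthetic regression data, used only to illustrate the API
Xr, yr = make_regression(n_samples=200, n_features=4, noise=0.1, random_state=42)
Xr_train, Xr_test, yr_train, yr_test = train_test_split(Xr, yr, test_size=0.2, random_state=42)

regressor = RandomForestRegressor(n_estimators=100, random_state=42)
regressor.fit(Xr_train, yr_train)
print(regressor.score(Xr_test, yr_test))  # R^2 score on the test split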