PySpark Linear Regression Get Coefficients
Last Updated: 28 Apr, 2025
In this tutorial series, we are going to cover Linear Regression using PySpark. Linear Regression is a supervised machine learning algorithm that models the relationship between one or more input features and a continuous target variable, so both the inputs and the outputs are known during training.
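Concretely, the model learns one coefficient (weight) per input feature plus an intercept, fitting an equation of the form ŷ = b0 + b1·x1 + b2·x2 + ... + bn·xn. Those coefficients are exactly what we will extract from the fitted model in Step 7.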
Loading the DataFrame:
We will be using the "E-commerce Customer" dataset, which contains data from a company's website and mobile app. The task is to predict a customer's yearly spending on the company's products.
Dataset link: [https://www.kaggle.com/datasets/pawankumargunjan/ecommercecustomers]
Step 1: Start the Spark session:
Python3
# Starting the Spark Session
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('LinearRegression').getOrCreate()
spark
Output:
SparkSession - in-memory
SparkContext
Spark UI
Version: v3.3.1
Master: local[*]
AppName: LinearRegression
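The Master entry local[*] means Spark is running locally and using all available cores. If you would rather cap the parallelism, you can pass the master explicitly when building the session; this is an optional sketch, and local[4] is just an illustrative choice:
Python3
# Optional: run locally on exactly 4 cores instead of all of them
spark = SparkSession.builder \
    .master('local[4]') \
    .appName('LinearRegression') \
    .getOrCreate()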
Step 2: Load the dataset:
Python3
# Reading the data
df = spark.read.csv('Ecommerce_Customers.csv', inferSchema=True, header=True)
# Showing the data
df.show(5)
Output:
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
| Email| Address| Avatar|Avg Session Length| Time on App| Time on Website|Length of Membership|Yearly Amount Spent|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
|mstephenson@ferna...|835 Frank TunnelW...| Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616| 4.0826206329529615| 587.9510539684005|
| [email protected]|4547 Archer Commo...| DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744| 2.66403418213262| 392.2049334443264|
| [email protected]|24645 Valerie Uni...| Bisque|33.000914755642675|11.330278057777512|37.110597442120856| 4.104543202376424| 487.54750486747207|
|riverarebecca@gma...|1414 David Throug...| SaddleBrown| 34.30555662975554|13.717513665142507| 36.72128267790313| 3.120178782748092| 581.8523440352177|
|mstephens@davidso...|14023 Rodriguez P...|MediumAquaMarine| 33.33067252364639|12.795188551078114| 37.53665330059473| 4.446308318351434| 599.4060920457634|
+--------------------+--------------------+----------------+------------------+------------------+------------------+--------------------+-------------------+
only showing top 5 rows
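Since we read the file with inferSchema=True, it is worth confirming the column types Spark inferred before modeling. A quick, optional check:
Python3
# Print the inferred schema (column names and data types)
df.printSchema()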
Step 3: Check the column names
Python3
# Show the column names of the DataFrame
df.columns
Output:
['Email',
'Address',
'Avatar',
'Avg Session Length',
'Time on App',
'Time on Website',
'Length of Membership',
'Yearly Amount Spent']
Step 4: Assemble the input columns into a single vector column, which will serve as the "features" for the model.
Python3
from pyspark.ml.feature import VectorAssembler

# Combine the four numeric columns into a single vector column
assembler = VectorAssembler(
    inputCols=['Avg Session Length', 'Time on App',
               'Time on Website', 'Length of Membership'],
    outputCol='features')

output = assembler.transform(df)
output.select('features').show(5)
Output:
+--------------------+
| features|
+--------------------+
|[34.4972677251122...|
|[31.9262720263601...|
|[33.0009147556426...|
|[34.3055566297555...|
|[33.3306725236463...|
+--------------------+
only showing top 5 rows
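The vector values are truncated in the display above. To inspect a full assembled row, you can turn truncation off (optional):
Python3
# Show the complete feature vectors instead of the truncated view
output.select('features').show(5, truncate=False)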
Step 5: Split the data into a training set and a test set, which will be used for training and evaluation respectively.
Python3
# Keep only the features vector and the label column
final_data = output.select('features', 'Yearly Amount Spent')
# 70/30 random split into training and test sets
train_data, test_data = final_data.randomSplit([0.7, 0.3])
Let's describe the train data and test data.
Python3
train_data.describe().show()
test_data.describe().show()
Output:
+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
| count| 357|
| mean| 496.7071530755217|
| stddev| 80.03111843524778|
| min| 256.67058229005585|
| max| 765.5184619388373|
+-------+-------------------+
+-------+-------------------+
|summary|Yearly Amount Spent|
+-------+-------------------+
| count| 143|
| mean| 505.82213623310577|
| stddev| 77.39011604239676|
| min| 275.9184206503857|
| max| 744.2218671047146|
+-------+-------------------+
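Note that randomSplit is random, so your counts and statistics will differ slightly from the ones shown here on each run. If you need a reproducible split, you can pass a seed; the value 42 below is arbitrary:
Python3
# Reproducible 70/30 split (seed value is arbitrary)
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)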
Step 6: Create a Linear Regression model and fit it to the training data.
Python3
from pyspark.ml.regression import LinearRegression
# Create a Linear Regression Model object
lr = LinearRegression(labelCol='Yearly Amount Spent')
# Fit the model to the data and call this model lrModel
lrModel = lr.fit(train_data)
lrModel
Output:
LinearRegressionModel: uid=LinearRegression_74214a54e364, numFeatures=4
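By default featuresCol is 'features', which matches the column we assembled, and no regularization is applied. If you want a regularized fit, LinearRegression also accepts regParam and elasticNetParam; the values below are illustrative, not tuned:
Python3
# Optional: elastic-net regularization (example values, not tuned)
lr_reg = LinearRegression(labelCol='Yearly Amount Spent',
                          regParam=0.1, elasticNetParam=0.5)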
Step 7: Print the coefficients and intercept of the model
Python3
# Print the coefficients and intercept for linear regression
print("Coefficients: {}".format(lrModel.coefficients))
print('Intercept: {}'.format(lrModel.intercept))
Output:
Coefficients: [25.964105285025216,38.93669968512164,0.2862951403317341,61.42916517189798]
Intercept: -1055.4964671721655
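The coefficients appear in the same order as the inputCols we passed to the VectorAssembler, so 'Length of Membership' has the largest weight here. To make the pairing explicit, a small convenience sketch:
Python3
# Pair each coefficient with its feature name
for name, coef in zip(assembler.getInputCols(), lrModel.coefficients):
    print('{}: {}'.format(name, coef))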
Step 8: Evaluate the model on the test data:
Python3
test_results = lrModel.evaluate(test_data)
# Print the residuals: the difference between the actual value
# and the value predicted by the model (y - ŷ) for each point
test_results.residuals.show(5)
Output:
+-------------------+
| residuals|
+-------------------+
| 11.275316471318774|
| 0.6070843579793177|
| 6.966802347383464|
| -6.151576882623033|
|-7.3822955579703375|
+-------------------+
only showing top 5 rows
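Since residuals is itself a DataFrame, you can summarize it like any other column; a quick optional sanity check that the errors are centered near zero:
Python3
# Summary statistics of the residuals
test_results.residuals.describe().show()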
Step 9: Make predictions on new (unlabeled) data
Python3
unlabeled_data = test_data.select('features')
predictions = lrModel.transform(unlabeled_data)
predictions.show(5)
Output:
+--------------------+------------------+
| features| prediction|
+--------------------+------------------+
|[29.5324289670579...| 397.3650346013087|
|[30.5743636841713...|441.45732940008634|
|[30.9716756438877...|487.67180740950926|
|[31.0613251567161...|493.70703494052464|
|[31.1280900496166...| 564.634982305025|
+--------------------+------------------+
only showing top 5 rows
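Because this "unlabeled" data actually came from the labeled test set, you can also line the predictions up against the true values; the evaluation summary from Step 8 already holds both (optional check):
Python3
# Compare predicted vs. actual yearly spend on the test set
test_results.predictions.select('Yearly Amount Spent', 'prediction').show(5)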
Step 10: Calculate the Root Mean Squared Error (RMSE) and Mean Squared Error (MSE) to assess the model's performance:
Python3
print("RMSE: {}".format(test_results.rootMeanSquaredError))
print("MSE: {}".format(test_results.meanSquaredError))
Output:
RMSE: 9.965510046039142
MSE: 99.31139047770706
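An RMSE of roughly 10 is small relative to the test set's mean yearly spend of about 506, which indicates a good fit. The same summary object also exposes other metrics, such as R² and mean absolute error:
Python3
print("R2: {}".format(test_results.r2))
print("MAE: {}".format(test_results.meanAbsoluteError))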
Step 11: Stop the session
Python3
# Stop the Spark session
spark.stop()