Speed up Your ML Projects With Spark
Last Updated on June 25, 2024 by Editorial Team
Author(s): Mena Wang, PhD
Originally published on Towards AI.
Spark is an open-source distributed computing framework for high-speed data processing. It is widely supported by platforms like GCP and Azure, as well as Databricks, which was founded by the creators of Spark.
As a Python user, I find the pySpark library super handy for leveraging Spark's capacity to speed up data processing in machine learning projects. But here is a problem: while pySpark syntax is straightforward and very easy to follow, it can be readily confused with other common libraries for data wrangling. Here is a simple example:
# Pandas:
import pandas as pd

df.groupby('category').agg(
    avg_value1=('value1', 'mean'),
    sum_value2=('value2', 'sum')
).reset_index()

# pySpark:
from pyspark.sql import functions as F

df.groupBy('category').agg(
    F.mean('value1').alias('avg_value1'),
    F.sum('value2').alias('sum_value2')
)
You see what I mean? Similar but different. 🤪 If, like me, you need to write in multiple data-wrangling packages, including pySpark, and want to make life easier, this article is just for you! 🥂
To address this issue, with the additional benefit of keeping my code neatly organized, easy to reuse and maintain, I created custom pySpark helper functions that are stored in a separate file, spark_utils.py, within each project folder. This practice vastly speeds up my data preparation for machine learning projects. All you need to do is import the functions wherever they are needed, like below:
- my-project/
- EDA-demo.ipynb
- spark_utils.py
# then in EDA-demo.ipynb
import spark_utils as sut
I plan to share these helpful pySpark functions in a series of articles. This is the first one, where we look at some functions for data quality checks, which are the initial steps I take in EDA. Let's get started. 🤠
🔗 All code and config are available on GitHub. 🧰
The dummy data
While Spark is famous for its ability to work with big data, for demo purposes, I have created a small dataset with an obvious duplicate issue. Do you notice that the two ID fields, ID1 and ID2, do not form a primary key? We will use this table to demo and test our custom functions.
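For readers who want to follow along, a dummy table like the one described can be sketched as below; the exact values are assumptions based on the duplicates discussed later in the article (the 101-A, 102-B and 105-B ID combinations, the "New York"/"NY" city values and the "Bingo"/"Binggy" names), and the remaining names and cities are made up purely for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical dummy data with a deliberate duplicate issue:
# ID1 + ID2 do not uniquely identify a row.
dummy_data = [
    (101, "A", "Alice", "New York"),
    (101, "A", "Alice", "NY"),       # same IDs, City recorded differently
    (102, "B", "Bob", "Chicago"),
    (102, "B", "Bob", "Chi-town"),   # same IDs, City recorded differently
    (105, "B", "Bingo", "Boston"),
    (105, "B", "Binggy", "Boston"),  # same IDs, Name recorded differently
]
df = spark.createDataFrame(dummy_data, ["ID1", "ID2", "Name", "City"])
df.show()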
Shape
I find the shape attribute of pandas dataframes pretty convenient, so I created a custom function to get the shape of Spark dataframes too. A few things to note:
- This custom shape function prints out comma-formatted numbers, which can be especially helpful for big datasets.
- It can return the shape tuple for further programmatic use when the print_only parameter is set to False.
BTW, you might be delighted to learn that all the functions in this article are equipped with 1) docstring documentation and 2) type hints. You are welcome 😁
from pyspark.sql import DataFrame


def shape(df: DataFrame, print_only: bool = True):
    """
    Get the number of rows and columns in the DataFrame.

    Parameters:
    - df (DataFrame): The DataFrame to get the shape of.
    - print_only (bool): If True, only print out the shape. Default is True.

    Returns:
    - tuple or None: (num_rows, num_cols) if print_only is False, otherwise None
    """
    num_rows = df.count()
    num_cols = len(df.columns)
    print(f"Number of rows: {num_rows:,}")
    print(f"Number of columns: {num_cols:,}")
    if print_only:
        return None
    else:
        return num_rows, num_cols
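For example, assuming the dummy DataFrame df sketched above, the helper can be used like this:

import spark_utils as sut

sut.shape(df)  # prints the comma-formatted row and column counts
n_rows, n_cols = sut.shape(df, print_only=False)  # also returns the tuple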
Print schema
In pySpark, there is a built-in printSchema function. However, when working with very wide tables, I prefer to have the column names sorted alphabetically so I can check for things more effectively. Here is the function for that.
from pyspark.sql import DataFrame


def print_schema_alphabetically(df: DataFrame):
    """
    Prints the schema of the DataFrame with columns sorted alphabetically.

    Parameters:
    - df (DataFrame): The DataFrame whose schema is to be printed.

    Returns:
    None
    """
    sorted_columns = sorted(df.columns)
    sorted_df = df.select(sorted_columns)
    sorted_df.printSchema()
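Usage is a one-liner; on the dummy DataFrame sketched earlier, it would print the City, ID1, ID2 and Name fields in alphabetical order:

import spark_utils as sut

sut.print_schema_alphabetically(df)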
Verify primary key
A common EDA task is to check the primary key(s) and troubleshoot duplicates. The three functions below are created for this purpose. First, let's look at the is_primary_key function. As its name indicates, this function checks whether the specified column(s) form a primary key in the DataFrame. A few things to note:
- It returns False when the dataframe is empty, or when any of the specified columns are missing from the dataframe.
- It checks for missing values in any of the specified columns and excludes the relevant rows from the row counts.
- Using the verbose parameter, users can specify whether to print out or suppress detailed info during the function run.
from typing import List

from pyspark.sql import DataFrame
from pyspark.sql.functions import col


def is_primary_key(df: DataFrame, cols: List[str], verbose: bool = True) -> bool:
    """
    Check if the combination of specified columns forms
    a primary key in the DataFrame.

    Parameters:
    df (DataFrame): The DataFrame to check.
    cols (list): A list of column names to check for forming a primary key.
    verbose (bool): If True, print detailed information. Default is True.

    Returns:
    bool: True if the combination of columns forms a primary key, False otherwise.
    """
    # Check if the DataFrame is not empty
    if df.isEmpty():
        if verbose:
            print("DataFrame is empty.")
        return False

    # Check if all columns exist in the DataFrame
    missing_cols = [col_name for col_name in cols if col_name not in df.columns]
    if missing_cols:
        if verbose:
            print(f"Column(s) {', '.join(missing_cols)} do not exist in the DataFrame.")
        return False

    # Check for missing values in each specified column
    for col_name in cols:
        missing_rows_count = df.where(col(col_name).isNull()).count()
        if missing_rows_count > 0:
            if verbose:
                print(f"There are {missing_rows_count:,} row(s) with missing values in column '{col_name}'.")

    # Filter out rows with missing values in any of the specified columns
    filtered_df = df.dropna(subset=cols)

    # Check if the combination of columns is unique after filtering out missing value rows
    unique_row_count = filtered_df.select(*cols).distinct().count()
    total_row_count = filtered_df.count()
    if verbose:
        print(f"Total row count after filtering out missings: {total_row_count:,}")
        print(f"Unique row count after filtering out missings: {unique_row_count:,}")

    if unique_row_count == total_row_count:
        if verbose:
            print(f"The column(s) {', '.join(cols)} form a primary key.")
        return True
    else:
        if verbose:
            print(f"The column(s) {', '.join(cols)} do not form a primary key.")
        return False
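As a quick sanity check on the dummy DataFrame sketched earlier (hypothetical column names), the call would look like this:

import spark_utils as sut

sut.is_primary_key(df, ['ID1', 'ID2'])
# Expected to return False here, because the same ID1/ID2 pairs appear on multiple rows.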
Find duplicates
Consistent with our inspection of the dummy table, the two ID fields do not form a primary key. Of course, duplicates can exist in real data too; below is the function to identify them. 🔎
from typing import List

from pyspark.sql import DataFrame
from pyspark.sql.functions import col


def find_duplicates(df: DataFrame, cols: List[str]) -> DataFrame:
    """
    Function to find duplicate rows based on specified columns.

    Parameters:
    - df (DataFrame): The DataFrame to check.
    - cols (list): List of column names to check for duplicates

    Returns:
    - duplicates (DataFrame): PySpark DataFrame containing duplicate rows based on the specified columns,
      with the 'count' column and the specified columns as the first columns,
      along with the rest of the columns from the original DataFrame
    """
    # Filter out rows with missing values in any of the specified columns
    for col_name in cols:
        df = df.filter(col(col_name).isNotNull())

    # Group by the specified columns and count the occurrences
    dup_counts = df.groupBy(*cols).count()

    # Filter to retain only the rows with count greater than 1
    duplicates = dup_counts.filter(col("count") > 1)

    # Join with the original DataFrame to include all columns
    duplicates = duplicates.join(df, cols, "inner")

    # Reorder columns with 'count' as the first column
    duplicate_cols = ['count'] + cols
    duplicates = duplicates.select(*duplicate_cols, *[c for c in df.columns if c not in cols])

    return duplicates
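To pull out the offending rows from the dummy DataFrame sketched earlier, a call might look like this:

import spark_utils as sut

dups = sut.find_duplicates(df, ['ID1', 'ID2'])
dups.show()  # rows sharing an ID1/ID2 pair, with their occurrence count as the first column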
Columns responsible for the duplicates
From the above table, it is fairly easy to tell which columns are responsible for duplications in our data.
- 🔎 The City column is responsible for the differences in the 101-A and 102-B ID combinations. For example, the dup in 101-A is because the City is recorded both as "New York" and "NY".
- 🔎 The Name column is responsible for the difference in the 105-B ID combination, where the person's name is "Bingo" in one record and "Binggy" in another.
Identifying the root cause of the dups is important for troubleshooting. For instance, based on the discovery above, we should consolidate both city and person names in our data.
You can imagine that when we have very wide tables and many more dups, identifying and summarizing the root cause using human eyes 👀 like we did above becomes much trickier.
The cols_responsible_for_id_dups function comes to the rescue by summarizing the difference_counts for each column based on the primary key(s) provided. 😎 For example, in the output below, we can easily see that the field City is responsible for differences in two unique ID combinations, while the Name column is responsible for the dups in one ID pair.
from typing import List

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col


def cols_responsible_for_id_dups(spark_df: DataFrame, id_list: List[str]) -> DataFrame:
    """
    This diagnostic function checks each column
    for each unique id combination to see whether there are differences.
    This can be used to identify columns responsible for most duplicates
    and help with troubleshooting.

    Parameters:
    - spark_df (DataFrame): The Spark DataFrame to analyze.
    - id_list (list): A list of column names representing the ID columns.

    Returns:
    - summary_table (DataFrame): A Spark DataFrame containing two columns
      'col_name' and 'difference_counts'.
      It represents the count of differing values for each column
      across all unique ID column combinations.
    """
    # Get or create the SparkSession
    spark = SparkSession.builder.getOrCreate()

    # Filter out rows with missing values in any of the ID columns
    filtered_df = spark_df.na.drop(subset=id_list)

    # Define a function to count differences within a column for unique id_list combinations
    def count_differences(col_name):
        """
        Counts the number of differing values for each col_name.

        Args:
        - col_name (str): The name of the column to analyze.

        Returns:
        - count (int): The count of differing values.
        """
        # Count the number of distinct values for each combination of ID columns and current column
        distinct_count = filtered_df.groupBy(*id_list, col_name).count().groupBy(*id_list).count()
        return distinct_count.filter(col("count") > 1).count()

    # Get the column names excluding the ID columns
    value_cols = [col_name for col_name in spark_df.columns if col_name not in id_list]

    # Create a DataFrame to store the summary table
    summary_data = [(col_name, count_differences(col_name)) for col_name in value_cols]
    summary_table = spark.createDataFrame(summary_data, ["col_name", "difference_counts"])

    # Sort the summary_table by "difference_counts" from large to small
    summary_table = summary_table.orderBy(col("difference_counts").desc())

    return summary_table
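On the dummy DataFrame sketched earlier, a run could look like the following; the exact counts depend on the assumed dummy values:

import spark_utils as sut

summary = sut.cols_responsible_for_id_dups(df, ['ID1', 'ID2'])
summary.show()
# Expected pattern with the assumed data: City differs within two ID1/ID2 combinations, Name within one.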
The columns responsible for the most duplicates are listed at the top of the summary table. We can then analyze these columns further for troubleshooting. If you have a very wide table, narrowing down the investigation like this can be pretty handy.
Value counts with percentage
For example, you can zoom in by checking relevant columns' value counts among the dups. And of course, I have a custom function ready for you to do just this. 😜 This function is very much like value_counts in pandas, with two additional features:
- percentage for each unique value
- comma-formatted numbers in the printout

Let's see it in action.
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count, format_string, round


def value_counts_with_pct(df: DataFrame, column_name: str) -> DataFrame:
    """
    Calculate the count and percentage of occurrences for each unique value
    in the specified column.

    Parameters:
    - df (DataFrame): The DataFrame containing the data.
    - column_name (str): The name of the column for which to calculate value counts.

    Returns:
    - DataFrame: A DataFrame containing the unique values in the specified column
      along with their corresponding count and percentage.
    """
    counts = df.groupBy(column_name).agg(
        count("*").alias("count"),
        (count("*") / df.count() * 100).alias("pct")
    )
    counts = counts.withColumn("pct", round(col("pct"), 2))
    counts = counts.orderBy(col("count").desc())

    # Format count column with comma spacing for printing
    formatted_counts = counts.withColumn("count", format_string("%,d", col("count")))
    formatted_counts.show()

    # Return counts DataFrame with raw numbers
    return counts
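For instance, to see how the City values are distributed among the duplicates found earlier (again using the hypothetical dummy columns), something like this would work:

import spark_utils as sut

dups = sut.find_duplicates(df, ['ID1', 'ID2'])
city_counts = sut.value_counts_with_pct(dups, 'City')  # prints formatted counts, returns raw numbers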
So here you go: above are some common functions we can use to streamline and speed up data preparation for ML projects with pySpark. Let me know if you like them; I have many more to share if you find them helpful. 😎