Speed up Your ML Projects With Spark

Last Updated on June 25, 2024 by Editorial Team

Author(s): Mena Wang, PhD

Originally published on Towards AI.

Image generated by Gemini

Spark is an open-source distributed computing framework for high-speed data processing. It is widely supported by platforms like GCP and Azure, as well as Databricks, which was founded by the creators of Spark.

As a Python user, I find the pySpark library super handy for leveraging Spark’s capacity to speed up data processing in machine learning projects. But here is a problem: while pySpark syntax is straightforward and easy to follow, it is easily confused with the syntax of other common data-wrangling libraries. Here is a simple example:

# Pandas:
import pandas as pd


# pySpark:
from pyspark.sql import functions as F


You see what I mean? Similar but different. 🤪 If you are like me, needing to write in multiple data-wrangling packages, including pySpark, and want to make life easier, this article is just for you! 🥂

To address this issue, with the added benefit of keeping the code neatly organized and easy to reuse and maintain, I created custom pySpark helper functions and store them in a separate file, spark_utils.py, within each project folder. This practice noticeably speeds up my data preparation for machine learning projects. All you need to do is import them where they are needed, as shown below:

- my-project/
  - EDA-demo.ipynb
  - spark_utils.py

# then in EDA-demo.ipynb
import spark_utils as sut

I plan to share these helpful pySpark functions in a series of articles. This is the first one, where we look at some functions for data quality checks, which are the initial steps I take in EDA. Let’s get started. 🤠

🔗 All code and config are available on GitHub. 🧰

The dummy data

While Spark is famous for its ability to work with big data, for demo purposes, I have created a small dataset with an obvious duplicate issue. Do you notice that the two ID fields, ID1 and ID2, do not form a primary key? We will use this table to demo and test our custom functions.

The dummy table for demo


I find the shape attribute of pandas DataFrames pretty convenient, so I created a custom function to get the shape of Spark DataFrames too. A few things to note:

  • This custom shape function prints out comma-formatted numbers, which can be especially helpful for big datasets.
  • It can return the shape tuple for further programmatic use when the print_only parameter is set to False.

BTW, you might be delighted to learn that all the functions in this article are equipped with 1) Docstring documentation and 2) Type hints. You are welcome 😁

from pyspark.sql import DataFrame


def shape(df: DataFrame, print_only: bool = True):
    """
    Get the number of rows and columns in the DataFrame.

    Parameters:
    - df (DataFrame): The DataFrame to get the shape of.
    - print_only (bool): If True, only print out the shape. Default is True.

    Returns:
    - tuple or None: (num_rows, num_cols) if print_only is False, otherwise None
    """
    num_rows = df.count()  # triggers a Spark job, so it can be slow on big data
    num_cols = len(df.columns)
    print(f"Number of rows: {num_rows:,}")
    print(f"Number of columns: {num_cols:,}")
    if print_only:
        return None
    return num_rows, num_cols

Demo output of the sut.shape() function

Print schema

In pySpark, there is a built-in printSchema function. However, when working with very wide tables, I prefer to have the column names sorted alphabetically so I can check for things more effectively. Here is the function for that.

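from pyspark.sql import DataFrame


def print_schema_alphabetically(df: DataFrame):
    """
    Prints the schema of the DataFrame with columns sorted alphabetically.

    Parameters:
    - df (DataFrame): The DataFrame whose schema is to be printed.
    """
    # Reorder the columns alphabetically, then print the schema as usual
    sorted_columns = sorted(df.columns)
    sorted_df = df.select(sorted_columns)
    sorted_df.printSchema()
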
Demo output of the sut.print_schema_alphabetically() function

Verify primary key

A common EDA task is to check the primary key(s) and troubleshoot duplicates. The three functions below are created for this purpose. First, let’s look at the is_primary_key function. As its name indicates, this function checks whether the specified column(s) form a primary key in the DataFrame. A few things to note:

  • It returns False when the dataframe is empty, or when any of the specified columns are missing from the dataframe.
  • It checks for missing values in any of the specified columns and excludes relevant rows from the row counts.
  • Using the verbose parameter, users can specify whether to print out or suppress detailed info during the function run.

from typing import List

from pyspark.sql import DataFrame
from pyspark.sql.functions import col


def is_primary_key(df: DataFrame, cols: List[str], verbose: bool = True) -> bool:
    """
    Check if the combination of specified columns forms
    a primary key in the DataFrame.

    Parameters:
    df (DataFrame): The DataFrame to check.
    cols (list): A list of column names to check for forming a primary key.
    verbose (bool): If True, print detailed information. Default is True.

    Returns:
    bool: True if the combination of columns forms a primary key, False otherwise.
    """
    # Check if the DataFrame is not empty
    if df.isEmpty():
        if verbose:
            print("DataFrame is empty.")
        return False

    # Check if all columns exist in the DataFrame
    missing_cols = [col_name for col_name in cols if col_name not in df.columns]
    if missing_cols:
        if verbose:
            print(f"Column(s) {', '.join(missing_cols)} do not exist in the DataFrame.")
        return False

    # Check for missing values in each specified column
    for col_name in cols:
        missing_rows_count = df.where(col(col_name).isNull()).count()
        if missing_rows_count > 0:
            if verbose:
                print(f"There are {missing_rows_count:,} row(s) with missing values in column '{col_name}'.")

    # Filter out rows with missing values in any of the specified columns
    filtered_df = df.dropna(subset=cols)

    # Check if the combination of columns is unique after filtering out missing value rows
    unique_row_count = filtered_df.select(*cols).distinct().count()
    total_row_count = filtered_df.count()

    if verbose:
        print(f"Total row count after filtering out missings: {total_row_count:,}")
        print(f"Unique row count after filtering out missings: {unique_row_count:,}")

    if unique_row_count == total_row_count:
        if verbose:
            print(f"The column(s) {', '.join(cols)} form a primary key.")
        return True
    if verbose:
        print(f"The column(s) {', '.join(cols)} do not form a primary key.")
    return False

Demo output of the sut.is_primary_key() function
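
To make the logic easy to follow without a Spark session, here is a minimal plain-Python sketch of the same check on a list of dictionaries. The name is_primary_key_py and the sample rows are hypothetical, made up for illustration; they are not part of spark_utils.py, and the rows are not the article's dummy table. The sketch assumes every row has the same keys.

```python
from typing import Any, Dict, List


def is_primary_key_py(rows: List[Dict[str, Any]], cols: List[str]) -> bool:
    """Plain-Python mirror of the primary-key check (illustrative only)."""
    # An empty table or a missing key column cannot form a primary key
    if not rows or any(c not in rows[0] for c in cols):
        return False
    # Drop rows where any key column is missing, like dropna(subset=cols)
    keys = [tuple(r[c] for c in cols) for r in rows
            if all(r[c] is not None for c in cols)]
    # The columns form a key when no key tuple repeats
    return len(set(keys)) == len(keys)


# Hypothetical sample rows, not the article's dummy table
pk_rows = [
    {"ID1": 101, "ID2": "A", "City": "New York"},
    {"ID1": 101, "ID2": "A", "City": "NY"},
    {"ID1": 102, "ID2": "B", "City": "Boston"},
    {"ID1": None, "ID2": "C", "City": "Austin"},
]
print(is_primary_key_py(pk_rows, ["ID1", "ID2"]))  # False: (101, "A") repeats
print(is_primary_key_py(pk_rows, ["City"]))        # True: every city is distinct
```

A small reference like this can also serve as an oracle when unit-testing the Spark helper on tiny inputs.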

Find duplicates

Consistent with our inspection of the dummy table, the two ID fields do not form a primary key. Of course, duplicates can exist in real data too. Below is the function to identify them. 🔎

from typing import List

from pyspark.sql import DataFrame
from pyspark.sql.functions import col


def find_duplicates(df: DataFrame, cols: List[str]) -> DataFrame:
    """
    Function to find duplicate rows based on specified columns.

    Parameters:
    - df (DataFrame): The DataFrame to check.
    - cols (list): List of column names to check for duplicates

    Returns:
    - duplicates (DataFrame): PySpark DataFrame containing duplicate rows based on the specified columns,
      with the specified columns and the 'count' column as the first columns,
      along with the rest of the columns from the original DataFrame
    """
    # Filter out rows with missing values in any of the specified columns
    for col_name in cols:
        df = df.filter(col(col_name).isNotNull())

    # Group by the specified columns and count the occurrences
    dup_counts = df.groupBy(*cols).count()

    # Filter to retain only the rows with count greater than 1
    duplicates = dup_counts.filter(col("count") > 1)

    # Join with the original DataFrame to include all columns
    duplicates = duplicates.join(df, cols, "inner")

    # Reorder columns with 'count' as the first column
    duplicate_cols = ['count'] + cols
    duplicates = duplicates.select(*duplicate_cols, *[c for c in df.columns if c not in cols])

    return duplicates

Demo output of the sut.find_duplicates() function
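
The group, filter, and join-back steps can be sketched in plain Python as well. This is only an illustration under assumed names: find_duplicates_py and the sample rows are hypothetical, not part of spark_utils.py.

```python
from collections import Counter
from typing import Any, Dict, List


def find_duplicates_py(rows: List[Dict[str, Any]], cols: List[str]) -> List[Dict[str, Any]]:
    """Plain-Python mirror of the duplicate finder (illustrative only)."""
    # Keep only rows whose key columns are all present
    kept = [r for r in rows if all(r.get(c) is not None for c in cols)]
    # Count how often each key combination occurs (the groupBy + count step)
    counts = Counter(tuple(r[c] for c in cols) for r in kept)
    # Keep rows whose key occurs more than once and attach the count (the join step)
    return [{"count": counts[tuple(r[c] for c in cols)], **r}
            for r in kept if counts[tuple(r[c] for c in cols)] > 1]


# Hypothetical sample rows, not the article's dummy table
dup_rows = [
    {"ID1": 101, "ID2": "A", "City": "New York"},
    {"ID1": 101, "ID2": "A", "City": "NY"},
    {"ID1": 103, "ID2": "A", "City": "Boston"},
    {"ID1": None, "ID2": "A", "City": "Austin"},
]
for row in find_duplicates_py(dup_rows, ["ID1", "ID2"]):
    print(row)
```

Only the two 101-A rows come back, each carrying a count of 2.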

Columns responsible for the duplicates

From the above table, it is fairly easy to tell which columns are responsible for duplications in our data.

  • 🔎 The City column is responsible for the differences in 101-A and 102-B ID combinations. For example, the dup in 101-A is because the City is recorded both as “New York” and “NY”.
  • 🔎 The Name column is responsible for the difference in the 105-B ID combination, where the person’s name is “Bingo” in one record and “Binggy” in another.

Identifying the root cause of the dups is important for troubleshooting. For instance, based on the discovery above, we should consolidate both city and person names in our data.

You can imagine that when we have very wide tables and many more dups, identifying and summarizing the root cause using human eyes 👀 like we did above becomes much trickier.

The cols_responsible_for_id_dups function comes to the rescue by summarizing the difference_counts for each column based on the primary key(s) provided. 😎 For example, in the output below, we can easily see that the City field is responsible for differences in two unique ID combinations, while the Name column is responsible for the dups in one ID pair.

from typing import List

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col


def cols_responsible_for_id_dups(spark_df: DataFrame, id_list: List[str]) -> DataFrame:
    """
    This diagnostic function checks each column
    for each unique id combination to see whether there are differences.
    This can be used to identify columns responsible for most duplicates
    and help with troubleshooting.

    Parameters:
    - spark_df (DataFrame): The Spark DataFrame to analyze.
    - id_list (list): A list of column names representing the ID columns.

    Returns:
    - summary_table (DataFrame): A Spark DataFrame containing two columns
      'col_name' and 'difference_counts'.
      It represents the count of differing values for each column
      across all unique ID column combinations.
    """
    # Get or create the SparkSession
    spark = SparkSession.builder.getOrCreate()

    # Filter out rows with missing values in any of the ID columns
    filtered_df = spark_df.na.drop(subset=id_list)

    # Define a function to count differences within a column for unique id_list combinations
    def count_differences(col_name):
        """
        Counts the number of differing values for each col_name.

        Parameters:
        - col_name (str): The name of the column to analyze.

        Returns:
        - count (int): The count of differing values.
        """
        # Count the number of distinct values for each combination of ID columns and current column
        distinct_count = filtered_df.groupBy(*id_list, col_name).count().groupBy(*id_list).count()
        return distinct_count.filter(col("count") > 1).count()

    # Get the column names excluding the ID columns
    value_cols = [col_name for col_name in spark_df.columns if col_name not in id_list]

    # Create a DataFrame to store the summary table (runs one Spark job per value column)
    summary_data = [(col_name, count_differences(col_name)) for col_name in value_cols]
    summary_table = spark.createDataFrame(summary_data, ["col_name", "difference_counts"])

    # Sort the summary_table by "difference_counts" from large to small
    summary_table = summary_table.orderBy(col("difference_counts").desc())

    return summary_table

Demo output of the sut.cols_responsible_for_id_dups() function

The columns responsible for the most duplicates are listed at the top of the summary table. We can then analyze these columns further for troubleshooting. If you have a very wide table, narrowing down the investigation like this can be pretty handy.
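
To see what difference_counts measures, here is a minimal plain-Python sketch of the same idea: for each non-ID column, count the ID combinations that hold more than one distinct value. The name difference_counts_py is hypothetical, and the sample rows are invented to loosely echo the City and Name discrepancies described above; they are not the article's actual dummy table.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple


def difference_counts_py(rows: List[Dict[str, Any]], id_cols: List[str]) -> List[Tuple[str, int]]:
    """Plain-Python mirror of the per-column difference summary (illustrative only)."""
    value_cols = [c for c in rows[0] if c not in id_cols]
    # For each value column, collect the distinct values seen per ID combination
    seen = {c: defaultdict(set) for c in value_cols}
    for r in rows:
        if any(r[c] is None for c in id_cols):
            continue  # skip rows with a missing ID, as the Spark helper does
        key = tuple(r[c] for c in id_cols)
        for c in value_cols:
            seen[c][key].add(r[c])
    # A column differs for an ID combination when it holds more than one distinct value
    summary = [(c, sum(len(v) > 1 for v in seen[c].values())) for c in value_cols]
    return sorted(summary, key=lambda t: t[1], reverse=True)


# Hypothetical rows loosely modeled on the discrepancies discussed in the text
diff_rows = [
    {"ID1": 101, "ID2": "A", "Name": "Alice", "City": "New York"},
    {"ID1": 101, "ID2": "A", "Name": "Alice", "City": "NY"},
    {"ID1": 102, "ID2": "B", "Name": "Bob", "City": "Los Angeles"},
    {"ID1": 102, "ID2": "B", "Name": "Bob", "City": "LA"},
    {"ID1": 105, "ID2": "B", "Name": "Bingo", "City": "Chicago"},
    {"ID1": 105, "ID2": "B", "Name": "Binggy", "City": "Chicago"},
    {"ID1": 103, "ID2": "A", "Name": "Cara", "City": "Boston"},
]
print(difference_counts_py(diff_rows, ["ID1", "ID2"]))  # [('City', 2), ('Name', 1)]
```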

Value counts with percentage

For example, you can zoom in by checking the value counts of the relevant columns among the dups. And of course, I have a custom function ready for you to do just this. 😜 This function is very much like value_counts in pandas, with two additional features:

  • percentage for each unique value
  • comma-formatted numbers in the printout

Let’s see it in action:

from pyspark.sql import DataFrame
# Note: round here is Spark's column function, which shadows Python's built-in round
from pyspark.sql.functions import col, count, format_string, round


def value_counts_with_pct(df: DataFrame, column_name: str) -> DataFrame:
    """
    Calculate the count and percentage of occurrences for each unique value
    in the specified column.

    Parameters:
    - df (DataFrame): The DataFrame containing the data.
    - column_name (str): The name of the column for which to calculate value counts.

    Returns:
    - DataFrame: A DataFrame containing the unique values in the specified column
      and their corresponding count and percentage.
    """
    total_rows = df.count()
    counts = df.groupBy(column_name).agg(
        count("*").alias("count"),
        (count("*") / total_rows * 100).alias("pct"),
    )

    counts = counts.withColumn("pct", round(col("pct"), 2))

    counts = counts.orderBy(col("count").desc())

    # Format count column with comma spacing for printing
    formatted_counts = counts.withColumn("count", format_string("%,d", col("count")))
    formatted_counts.show()

    # Return counts DataFrame with raw numbers
    return counts

Demo output of the sut.value_counts_with_pct() function
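
The arithmetic behind the count and pct columns can be sketched in plain Python with collections.Counter. The name value_counts_with_pct_py and the sample values are hypothetical and serve only to illustrate the calculation.

```python
from collections import Counter
from typing import Any, List, Tuple


def value_counts_with_pct_py(values: List[Any]) -> List[Tuple[Any, int, float]]:
    """Plain-Python mirror of the value counts with percentage (illustrative only)."""
    counts = Counter(values)
    total = len(values)
    # most_common() sorts by count, largest first, like orderBy(col("count").desc())
    return [(v, n, round(n / total * 100, 2)) for v, n in counts.most_common()]


# Hypothetical sample values
cities = ["NY", "NY", "LA", "NY", "LA", "SF"]
for value, n, pct in value_counts_with_pct_py(cities):
    # The ":," format spec adds thousands separators, like "%,d" in format_string
    print(f"{value}: {n:,} ({pct}%)")
```

Here NY accounts for 3 of 6 values (50.0%), LA for 2 (33.33%), and SF for 1 (16.67%).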

So there you go: these are some common functions we can use to streamline and speed up data preparation for ML projects with pySpark. Let me know if you like them; I have many more to share if you find them helpful. 😎

Published via Towards AI
