So far, this series has covered the background of PySpark structured streaming, the motivation for using applyInPandasWithState, and a notebook to generate streaming files. In part 3 of this tutorial on how to use applyInPandasWithState, the CSV files will be streamed, the data will be grouped by flight id, and custom logic to maintain the state will be implemented. This demo also introduces test functions, although only one test function is explored here. The video accompanying this tutorial goes into more detail on how to develop custom streaming logic for your own purposes and covers elements of the code in greater depth.
Download PySpark streaming example
You can download the notebook for this article and adapt it for your own purposes below.
To write the test functions, pytest and ipytest need to be installed. This is therefore the first step below.
%pip install pytest
%pip install ipytest
Imports
Here the standard modules are imported. The most significant items are GroupState and GroupStateTimeout. These are used in state management: GroupState stores the state for each group, and GroupStateTimeout configures whether a group state times out, so that it can be removed if there has been no activity for the designated period.
# Standard library imports
import time
from datetime import datetime, timedelta
# pandas imports
import pandas as pd
from pandas import Timestamp
# pytest imports
import pytest
import ipytest
# typing imports (Tuple and Iterator are used in the update_state signature)
from typing import Iterator, Tuple
# spark imports
from pyspark.sql.types import StructType, StructField, LongType, StringType, BooleanType, IntegerType, TimestampType, ArrayType
from pyspark.sql import functions as F
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
from delta.tables import DeltaTable
ipytest.autoconfig()
Optimisation settings
In the implementation on which this demo is based, 1 billion records were ingested per day. To help meet performance requirements, adaptive query execution is turned on so that query plans are re-optimised during runtime. To help process larger volumes of active states, the state store provider is set to the Databricks RocksDB provider, which can store around 100 times more state keys than the default state store.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.streaming.stateStore.providerClass","com.databricks.sql.streaming.state.RocksDBStateStoreProvider")
Global Variables
To run this notebook in your own environment, set these variables to your own values. This demo stores all the data in Unity Catalog and volumes, so you must have Unity Catalog and volumes set up. Set the catalog name, schema name and volume name to the location where you wrote the streaming files in the previous part of this article.
You will also need to set your storage account location and the folders you created in the previous step for vol_data_dump, vol_data_stream and vol_checkpoint_location.
# This notebook uses unity catalog to store the data.
# Change the following variables to point to your catalog, schema and volume
catalog_name = "pads-catalog"
schema_name = "demo-schema"
volume_name = "landing-zone"
delta_table_name = "flight_data_analysis"
# Change padsunitycatalogue to your storage account name.
# You can also specify a container name and directory location for the catalog e.g. : 'abfss://containername@
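If the data folders live inside the volume, Unity Catalog volumes are addressed with paths of the form /Volumes/<catalog>/<schema>/<volume>/<path>. The sketch below builds the three folder variables from the names above. The helper volume_path and the subfolder names data_dump, data_stream and checkpoint are hypothetical placeholders; replace them with the folders you actually created. The catalog, schema and volume names are repeated so the block runs on its own.

```python
# Sketch: build Unity Catalog volume paths for the notebook's folder variables.
# volume_path and the subfolder names are hypothetical placeholders.
catalog_name = "pads-catalog"
schema_name = "demo-schema"
volume_name = "landing-zone"


def volume_path(catalog: str, schema: str, volume: str, *parts: str) -> str:
    """Return a /Volumes/<catalog>/<schema>/<volume>/... path."""
    return "/".join(["/Volumes", catalog, schema, volume, *parts])


vol_data_dump = volume_path(catalog_name, schema_name, volume_name, "data_dump")
vol_data_stream = volume_path(catalog_name, schema_name, volume_name, "data_stream")
vol_checkpoint_location = volume_path(catalog_name, schema_name, volume_name, "checkpoint")

print(vol_data_stream)  # /Volumes/pads-catalog/demo-schema/landing-zone/data_stream
```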
Define Schemas
- data_iot_schema: Specifies the structure of the csv file that is streamed from the source folder.
- state_schema: This is the schema of the group state. It is a PySpark StructType; inside the state function its values are read and written as a tuple with one element per field. In this case its single field is an array of dictionary-like records defined by the data_iot_schema schema. This enables the entire data set for the group to be stored in one element of the tuple. This was needed in the project on which this demo is based, so that AI could analyse the entire data set and calculate a result to be stored in the group state.
- output_schema: This is the schema of the data that is returned and subsequently written to an output table.
# This is the schema of the CSV files and the data stored in the tuple state
data_iot_schema = StructType([
StructField("flight_id", IntegerType(), True),
StructField("aircraft_id", IntegerType(), True),
StructField("left_engine_oil_pressure", IntegerType(), True),
StructField("right_engine_oil_pressure", IntegerType(), True),
StructField("left_engine_fuel_flow_rate", IntegerType(), True),
StructField("right_engine_fuel_flow_rate", IntegerType(), True),
StructField("left_engine_vibration", IntegerType(), True),
StructField("right_engine_vibration", IntegerType(), True),
StructField("date_time", TimestampType(), True),
StructField("landed", IntegerType(), True)
])
# The state schema specifies an array of dictionaries containing the IoT data points for a particular group
state_schema = StructType([
StructField("IoT_readings", ArrayType(data_iot_schema), True)
])
# This is the schema of the output data frame that is to be written to a delta table
output_schema = StructType([
StructField("flight_id", IntegerType(), True),
StructField("IoT_readings", ArrayType(data_iot_schema), True),
StructField("left_engine_vibration_pass", BooleanType(), True),
StructField("right_engine_vibration_pass", BooleanType(), True),
StructField("state_status", StringType(), True)
])
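Because state_schema has a single IoT_readings field, the state is handled in the code as a one-element tuple whose only item is the list of readings. A small stdlib illustration of packing and unpacking that shape, using made-up readings with only a few of the data_iot_schema fields:

```python
# Illustration of the (IoT_readings,) tuple shape described by state_schema.
# The readings are made-up sample values with a subset of the data_iot_schema fields.
readings = [
    {"flight_id": 20, "left_engine_vibration": 21, "right_engine_vibration": 72, "landed": 0},
    {"flight_id": 20, "left_engine_vibration": 30, "right_engine_vibration": 70, "landed": 0},
]

# Pack: a one-element tuple, mirroring the single IoT_readings field
tuple_state = (readings,)

# Unpack: the trailing comma destructures a one-element tuple,
# the same pattern used for (int_id,) = key and (list_prev_IoT_readings,) = state.get
(list_prev_IoT_readings,) = tuple_state

# New readings are appended by list concatenation before the state is updated
new_readings = [{"flight_id": 20, "left_engine_vibration": 25, "right_engine_vibration": 60, "landed": 1}]
combined = list_prev_IoT_readings + new_readings
print(len(combined))  # 3
```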
Convert pandas data frame to list
Rather than writing all of the code in one place, helper functions are used. The applyInPandasWithState function passes pandas data frames containing the data for each group, whereas the state is stored as a tuple. The convert_pd_df_to_list function converts the records in a pandas data frame to a list of dictionaries, which can then be stored in the tuple.
def convert_pd_df_to_list(df_for_group: pd.DataFrame):
    ###########################################################
    # This function loops through each row in the pandas data frame and
    # adds the row as a dictionary into a list.
    # List used to contain all readings for a particular group
    list_IoT_readings = []
    for index, row in df_for_group.iterrows():
        list_IoT_readings.append(row.to_dict())
    return list_IoT_readings
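Iterating with iterrows is easy to read but builds a Series for every row, which is slow on large frames. pandas also provides DataFrame.to_dict(orient="records"), which returns a list with one dictionary per row in a single call. A sketch of an equivalent helper (convert_pd_df_to_list_records is a hypothetical name) checked against a loop-based version on a small made-up frame:

```python
import pandas as pd


def convert_pd_df_to_list_loop(df_for_group: pd.DataFrame) -> list:
    # Loop-based version, equivalent to convert_pd_df_to_list in the article
    return [row.to_dict() for _, row in df_for_group.iterrows()]


def convert_pd_df_to_list_records(df_for_group: pd.DataFrame) -> list:
    # One dictionary per row, keyed by column name, built in a single call
    return df_for_group.to_dict(orient="records")


df = pd.DataFrame(
    {
        "flight_id": [20, 20],
        "left_engine_vibration": [21, 29],
        "date_time": pd.to_datetime(["2025-03-27 17:52:03.397", "2025-03-27 17:52:03.398"]),
        "landed": [0, 1],
    }
)

print(convert_pd_df_to_list_records(df) == convert_pd_df_to_list_loop(df))  # True
```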
Analyse group data points
This is the function that you should customise for your own requirements. It receives the group data stored in the state, applies some business logic and then returns the results. The results returned by this function are subsequently written to the output delta table.
For this demo, the function accepts a list of all data recorded for a particular flight. It then checks whether any engine vibration value is at or above an acceptable limit (the threshold, 95 by default). If so, the pass flag for that engine is set to False. It also checks whether the flight has now landed and returns that result as well.
In the project on which this example is based, this function applied a series of AI algorithms and other business logic. The results were returned and written to the output delta table, so that immediate pre-emptive action could be taken based on them.
def agg_analyse_data(combined_list: list, threshold: int = 95):
    # This function inspects all the data points for a data group and
    # determines if the threshold has been reached. It then returns the updated state of the parameters.
    # This function could also be used to apply an AI algorithm to a set of IoT data points and return a result.
    left_engine_vibration_pass = True
    right_engine_vibration_pass = True
    has_landed = False
    # Derive IoT state from complete list of data points
    for element in combined_list:
        # Process each item
        left_engine_vibration_value = element["left_engine_vibration"]
        right_engine_vibration_value = element["right_engine_vibration"]
        landed = element["landed"]
        if left_engine_vibration_value >= threshold:
            left_engine_vibration_pass = False
        if right_engine_vibration_value >= threshold:
            right_engine_vibration_pass = False
        if landed == 1:
            has_landed = True
        # Break out of the loop once both vibration checks have failed and the flight
        # has landed, since further readings cannot change the result
        if not left_engine_vibration_pass and not right_engine_vibration_pass and has_landed:
            break
    return (left_engine_vibration_pass, right_engine_vibration_pass, has_landed)
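The same checks can be written more compactly with Python's built-in any(), which stops at the first matching reading. The sketch below is a hypothetical agg_analyse_data_compact that should return the same tuple as agg_analyse_data for well-formed input; the readings are made-up values.

```python
def agg_analyse_data_compact(combined_list: list, threshold: int = 95) -> tuple:
    # A vibration check passes only if no reading reaches the threshold (>=, as in agg_analyse_data)
    left_engine_vibration_pass = not any(r["left_engine_vibration"] >= threshold for r in combined_list)
    right_engine_vibration_pass = not any(r["right_engine_vibration"] >= threshold for r in combined_list)
    # The flight has landed if any reading has landed == 1
    has_landed = any(r["landed"] == 1 for r in combined_list)
    return (left_engine_vibration_pass, right_engine_vibration_pass, has_landed)


readings = [
    {"left_engine_vibration": 21, "right_engine_vibration": 72, "landed": 0},
    {"left_engine_vibration": 97, "right_engine_vibration": 40, "landed": 0},
    {"left_engine_vibration": 30, "right_engine_vibration": 50, "landed": 1},
]
print(agg_analyse_data_compact(readings))  # (False, True, True)
```

The trade-off is that the list may be scanned up to three times instead of once, in exchange for each check being a single readable expression.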
Update state function
The update_state function manages what data is stored in the streaming group state. This function will broadly follow the same structure for all implementations; what changes is what you store in the group state and what data set you return to be written to the output table. If you find this hard to follow, try watching the video, where more time is taken to describe it.
It follows this structure:
- First, it checks whether the state has timed out (state.hasTimedOut). A timeout guards against a state being held indefinitely if an issue occurs with the data, and can also be used to hold a state only for a particular time window. If the state has timed out, the data in the state is retrieved and returned to be written to the output table, and the state is then removed.
- If the state has not timed out, it collects any newly streamed data and checks whether a state already exists (state.exists). If it does, the existing data is retrieved from the state and combined with the new batch data. The combined data is placed in a tuple and the state is updated with state.update.
- The combined data is then analysed by agg_analyse_data, and the results are returned in three variables: left_engine_vibration_pass, right_engine_vibration_pass and has_landed. These results are combined with the flight id, the readings and a status message into df_output, the data frame that is returned. If the flight has now landed, the state is removed.
- If the state does not already exist, the same process is applied to the newly streamed data alone, and the state is initialised with it.
- Each time the function is called, the timeout period can be reset with state.setTimeoutDuration(timeout_duration_ms), so that the group state times out if no activity is seen for that period. In this demo these lines are commented out and timeoutConf is set to GroupStateTimeout.NoTimeout.
- Any exceptions that occur are caught by a try/except block, and the error message is returned in the output data set for ease of debugging and later retrieval.
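To make the branch structure concrete, here is a stdlib-only sketch of the same decision flow for a single group. A plain dictionary stands in for GroupState, lists of dictionaries stand in for the pandas data frames, and each reading carries a single vibration value. The function step and its status strings are hypothetical simplifications that mirror the article's logic, not the PySpark API.

```python
# Stdlib sketch of the update_state decision flow for ONE group.
# `store` is a plain dict standing in for GroupState (hypothetical simplification):
# it holds {"readings": [...]} while the state exists.
def step(store: dict, new_readings: list, timed_out: bool = False, threshold: int = 95) -> dict:
    if timed_out:
        # Timed out: emit whatever is held, then remove the state
        readings = store.pop("readings", [])
        status = "state group timed out"
    else:
        if "readings" in store:
            # State exists: combine stored readings with the new ones
            readings = store["readings"] + new_readings
            status = "state group updated"
        else:
            # No state yet: initialise it from the new readings
            readings = list(new_readings)
            status = "state group initiated"
        store["readings"] = readings
        # Stop holding the group once the flight has landed
        if any(r["landed"] == 1 for r in readings):
            del store["readings"]
    # Simplified analysis: one vibration value per reading
    vibration_pass = not any(r["vibration"] >= threshold for r in readings)
    return {"n_readings": len(readings), "vibration_pass": vibration_pass, "status": status}


store = {}
print(step(store, [{"vibration": 20, "landed": 0}]))
# {'n_readings': 1, 'vibration_pass': True, 'status': 'state group initiated'}
print(step(store, [{"vibration": 96, "landed": 0}]))
# {'n_readings': 2, 'vibration_pass': False, 'status': 'state group updated'}
print(step(store, [{"vibration": 30, "landed": 1}]))
# {'n_readings': 3, 'vibration_pass': False, 'status': 'state group updated'}
print(store)
# {}
```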
def update_state(key: Tuple[int], pdfs: Iterator[pd.DataFrame], state: GroupState) -> Iterator[pd.DataFrame]:
    # Initialise variables
    str_state_status = "state_status: success"  # Status is returned in output
    (int_id,) = key  # Used to get the id from the key
    threshold = 95  # Engine vibration limit passed to agg_analyse_data
    list_IoT_readings = []  # Used to contain all readings for a particular group
    tuple_state = ()  # Used to store individual IoT readings in a state
    df_output = pd.DataFrame()  # initialise dataframe to return to output
    try:
        if state.hasTimedOut:
            ####################################################
            # State has timed out so return the final state and clear it from memory
            # Get final tuple stored in the state
            (list_prev_IoT_readings,) = state.get
            ####################################################
            # Construct final id, IoT readings, IoT state, and status into a dataframe to return to output
            (left_engine_vibration_pass, right_engine_vibration_pass, has_landed) = agg_analyse_data(list_prev_IoT_readings, threshold=threshold)
            # Set status to record that the group state timed out
            str_state_status = "state_status: state group timed out"
            # Create dataframe to yield with flight id, flight readings and analysis results
            dict_data = {'flight_id': [int_id], 'IoT_readings': [list_prev_IoT_readings], 'left_engine_vibration_pass': [left_engine_vibration_pass], 'right_engine_vibration_pass': [right_engine_vibration_pass], 'state_status': [str_state_status]}
            df_output = pd.DataFrame(dict_data)
            # Remove state
            state.remove()
        else:
            ####################################################
            # State has not timed out so process new data
            #####################################
            # Get the new data
            # Loop through each pandas dataframe in the iterator. Spark may split the
            # rows of one group across several dataframes, so collect all of them.
            for df_for_group in pdfs:
                list_IoT_readings += convert_pd_df_to_list(df_for_group)
            if state.exists:
                ####################################################
                # State already exists so append the new data
                # to the existing data
                # Get the list of data stored in the current tuple state
                (list_prev_IoT_readings,) = state.get
                # Append the new data list to the existing list of data points
                combined_IoT_readings = list_prev_IoT_readings + list_IoT_readings
                # Place IoT readings list into a tuple
                tuple_state = (combined_IoT_readings,)
                # Update the current state
                state.update(tuple_state)
                ####################################################
                # Construct IoT readings state and dataframe to return to output
                (left_engine_vibration_pass, right_engine_vibration_pass, has_landed) = agg_analyse_data(combined_IoT_readings, threshold=threshold)
                str_state_status = "state_status: state group updated"
                # Create dataframe to yield with flight id, flight readings and analysis results
                dict_data = {'flight_id': [int_id], 'IoT_readings': [combined_IoT_readings], 'left_engine_vibration_pass': [left_engine_vibration_pass], 'right_engine_vibration_pass': [right_engine_vibration_pass], 'state_status': [str_state_status]}
                df_output = pd.DataFrame(dict_data)
                # Stop holding group data in state once the flight has landed
                if has_landed:
                    # Remove state
                    state.remove()
            else:
                ####################################################
                # State does not exist yet so initialise it.
                # Place IoT readings list into a tuple
                tuple_state = (list_IoT_readings,)
                # Update the state with the initialised tuple
                state.update(tuple_state)
                ####################################################
                # Construct IoT readings state and dataframe to return to output
                (left_engine_vibration_pass, right_engine_vibration_pass, has_landed) = agg_analyse_data(list_IoT_readings, threshold=threshold)
                str_state_status = "state_status: state group initiated"
                # Create dataframe to yield with flight id, flight readings and analysis results
                dict_data = {'flight_id': [int_id], 'IoT_readings': [list_IoT_readings], 'left_engine_vibration_pass': [left_engine_vibration_pass], 'right_engine_vibration_pass': [right_engine_vibration_pass], 'state_status': [str_state_status]}
                df_output = pd.DataFrame(dict_data)
                # Stop holding group data in state once the flight has landed
                if has_landed:
                    # Remove state
                    state.remove()
            ####################################################
            # Reset timeout period for group state.
            # Assuming a maximum flight duration of 3 hours, the timeout would be 1000 * 60 * 60 * 3 ms;
            # note that the value below (0.005 hours) is only 18 seconds.
            # To enable the timeout, uncomment these lines and set
            # timeoutConf=GroupStateTimeout.ProcessingTimeTimeout in applyInPandasWithState.
            #timeout_duration_ms = int(1000 * 60 * 60 * 0.005)
            #state.setTimeoutDuration(timeout_duration_ms)
    except Exception as e:
        str_state_status = f"state_status: Error occurred - {e}"
        # Return the stored readings if a state exists, otherwise the readings received in this call
        list_prev_IoT_readings = state.get[0] if state.exists else list_IoT_readings
        ############################################
        # return error data
        dict_data = {'flight_id': [int_id], 'IoT_readings': [list_prev_IoT_readings], 'left_engine_vibration_pass': [None], 'right_engine_vibration_pass': [None], 'state_status': [str_state_status]}
        df_output = pd.DataFrame(dict_data)
    # Yield dataframe back to calling function
    yield df_output
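GroupState.setTimeoutDuration takes a duration in milliseconds. One way to avoid arithmetic slips, such as the 0.005-hour value in the commented-out code (18 seconds rather than 3 hours), is to derive the value from a timedelta. The helper flight_timeout_ms below is hypothetical:

```python
from datetime import timedelta


def flight_timeout_ms(duration: timedelta) -> int:
    # Floor-divide by one millisecond to get a whole number of milliseconds
    return duration // timedelta(milliseconds=1)


print(flight_timeout_ms(timedelta(hours=3)))      # 10800000
print(flight_timeout_ms(timedelta(hours=0.005)))  # 18000
```

Inside update_state the result could be passed as state.setTimeoutDuration(flight_timeout_ms(timedelta(hours=3))); the timeout only takes effect when timeoutConf is GroupStateTimeout.ProcessingTimeTimeout.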
Read data from stream
The prerequisite work to manage the group state and define what data is returned has now been completed. Now it is time to make use of it with the standard streaming syntax. Below the csv files are streamed using the data_iot_schema.
flight_data_stream_df = spark \
.readStream \
.option("sep", ",") \
.option("header", "true") \
.schema(data_iot_schema) \
.csv(vol_data_stream)
applyInPandasWithState
The flight_data_stream_df is now grouped by flight_id, which defines the scope of the data stored in each group state. The custom state management is then implemented by the applyInPandasWithState function. It accepts several parameters: the function that manages the group state (func), the schema of the data that is returned (outputStructType), the schema of the group state (stateStructType), the output mode (outputMode) and finally whether the group state should time out (timeoutConf).
# Apply the stateful processing by updating the
# state every time a new data set is received
df_result = flight_data_stream_df.groupBy(F.col("flight_id")) \
.applyInPandasWithState( func=update_state,
outputStructType=output_schema,
stateStructType=state_schema,
outputMode="append",
timeoutConf=GroupStateTimeout.NoTimeout
)
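Conceptually, for each micro-batch Spark splits the rows by the grouping key and calls the state function once per key that has data (or has timed out), passing that key's own state. The stdlib sketch below imitates only that dispatch, using a dictionary of per-key stores; run_micro_batch and append_readings are hypothetical names, and the real engine additionally handles partitioning, Arrow batching and state checkpointing.

```python
from collections import defaultdict


def append_readings(key, rows, store):
    # Per-key state function: keep a running list of readings for the key
    store.setdefault("readings", []).extend(rows)
    return {"flight_id": key, "n_readings": len(store["readings"])}


def run_micro_batch(rows, states, func):
    # Group the micro-batch rows by flight_id, as groupBy(F.col("flight_id")) does
    groups = defaultdict(list)
    for row in rows:
        groups[row["flight_id"]].append(row)
    # Call the state function once per key present in this batch, with that key's state
    return [func(key, group_rows, states.setdefault(key, {})) for key, group_rows in groups.items()]


states = {}
batch_1 = [{"flight_id": 20, "landed": 0}, {"flight_id": 21, "landed": 0}, {"flight_id": 20, "landed": 0}]
batch_2 = [{"flight_id": 20, "landed": 1}]
print(run_micro_batch(batch_1, states, append_readings))
# [{'flight_id': 20, 'n_readings': 2}, {'flight_id': 21, 'n_readings': 1}]
print(run_micro_batch(batch_2, states, append_readings))
# [{'flight_id': 20, 'n_readings': 3}]
```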
Create and optimise delta table
A delta table to which the streamed data can be written is created below. The columns of the table are specified by adding the output_schema. Liquid clustering on flight_id is then applied with the ALTER TABLE ... CLUSTER BY statement, and the table is optimised based on the liquid clustering.
# Create the Delta table if it does not exist
DeltaTable.createIfNotExists(spark) \
.tableName(f"`{catalog_name}`.`{schema_name}`.`{delta_table_name}`") \
.addColumns(output_schema) \
.execute()
spark.sql(f"ALTER TABLE `{catalog_name}`.`{schema_name}`.`{delta_table_name}` CLUSTER BY (flight_id)")
deltaTable = DeltaTable.forName(spark, f"`{catalog_name}`.`{schema_name}`.`{delta_table_name}`")
deltaTable.optimize().executeCompaction()
Upsert delta table
To write the streamed data to a delta table, the writeStream method is used. However, the foreachBatch method is needed to upsert data into the delta table, because append mode only writes new records to the table. By using foreachBatch, new data is inserted and existing data is updated: foreachBatch calls the batch_upsert function, which runs a merge statement with each micro-batch of streamed data. Update output mode is not an option, because the Delta sink supports only append and complete output modes and cannot update existing rows.
delta_table = DeltaTable.forName(spark, f"`{catalog_name}`.`{schema_name}`.`{delta_table_name}`")
# Function to upsert microBatchOutputDF into Delta table using merge
def batch_upsert(micro_batch_output_df, batch_id):
    (
        delta_table.alias("t")
        .merge(micro_batch_output_df.alias("b"), "b.flight_id = t.flight_id")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )
# Write the output of a streaming aggregation query into Delta table
(
    df_result.writeStream
    .foreachBatch(batch_upsert)
    .outputMode("append")
    .option("checkpointLocation", vol_checkpoint_location)
    .start()
)
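The effect of the merge in batch_upsert can be pictured with a dictionary keyed by flight_id: matched keys have all columns replaced (whenMatchedUpdateAll) and unmatched keys are inserted (whenNotMatchedInsertAll). The stdlib sketch below contrasts that with plain appending; upsert_rows is a hypothetical helper, not part of the Delta Lake API.

```python
def upsert_rows(table: dict, batch: list, key: str = "flight_id") -> dict:
    # Merge a micro-batch into a table keyed by `key`
    for row in batch:
        # Matched keys are replaced with the new row; unmatched keys are inserted
        table[row[key]] = dict(row)
    return table


appended = []
upserted = {}
batch_1 = [{"flight_id": 20, "state_status": "state_status: state group initiated"}]
batch_2 = [{"flight_id": 20, "state_status": "state_status: state group updated"}]
for batch in (batch_1, batch_2):
    appended.extend(batch)       # append mode: every version is kept
    upsert_rows(upserted, batch)  # merge: one row per flight_id

print(len(appended))                 # 2
print(upserted[20]["state_status"])  # state_status: state group updated
```

Unlike this sketch, where the last row for a key wins, a Delta merge fails if more than one source row matches the same target row, so each micro-batch should contain at most one row per flight_id. update_state yields a single row per group on each call, which satisfies this.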
Test notebook
Each element of the notebook should be tested. Here ipytest is used rather than pytest alone, because the code sits within a notebook rather than a Python file. This test checks the functionality of the convert_pd_df_to_list function by comparing its output with an expected result. There are more test functions in the downloadable notebook at the top of this page, and the main thing you will want to use them for is testing the update_state function. The video discusses the approach to developing the update_state function. Testing update_state can be difficult, since it requires specific parameters to be passed, such as a GroupState.
%%ipytest
def test_convert_pd_df_to_list():
    test_df = pd.DataFrame([
        {'flight_id': 20, 'aircraft_id': 0, 'left_engine_oil_pressure': 42, 'right_engine_oil_pressure': 52, 'left_engine_fuel_flow_rate': 69, 'right_engine_fuel_flow_rate': 50, 'left_engine_vibration': 21, 'right_engine_vibration': 72, 'date_time': Timestamp('2025-03-27 17:52:03.397000'), 'landed': 0},
        {'flight_id': 21, 'aircraft_id': 1, 'left_engine_oil_pressure': 79, 'right_engine_oil_pressure': 56, 'left_engine_fuel_flow_rate': 62, 'right_engine_fuel_flow_rate': 20, 'left_engine_vibration': 29, 'right_engine_vibration': 43, 'date_time': Timestamp('2025-03-27 17:52:03.398000'), 'landed': 0}
    ], index=['a', 'b'])
    expected_result = [
        {'flight_id': 20, 'aircraft_id': 0, 'left_engine_oil_pressure': 42, 'right_engine_oil_pressure': 52, 'left_engine_fuel_flow_rate': 69, 'right_engine_fuel_flow_rate': 50, 'left_engine_vibration': 21, 'right_engine_vibration': 72, 'date_time': Timestamp('2025-03-27 17:52:03.397000'), 'landed': 0},
        {'flight_id': 21, 'aircraft_id': 1, 'left_engine_oil_pressure': 79, 'right_engine_oil_pressure': 56, 'left_engine_fuel_flow_rate': 62, 'right_engine_fuel_flow_rate': 20, 'left_engine_vibration': 29, 'right_engine_vibration': 43, 'date_time': Timestamp('2025-03-27 17:52:03.398000'), 'landed': 0}
    ]
    actual_result = convert_pd_df_to_list(test_df)
    assert actual_result == expected_result
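One way to unit test update_state without a running stream is to pass it a stand-in object exposing the members the function uses: the exists, get and hasTimedOut properties and the update, remove and setTimeoutDuration methods. The FakeGroupState class below is a hypothetical test double written for this purpose, not part of PySpark. So that the block runs on its own, it is exercised with a small hypothetical function, append_to_state, that follows the same exists/get/update pattern as update_state.

```python
class FakeGroupState:
    """Minimal stand-in for pyspark.sql.streaming.state.GroupState in unit tests."""

    def __init__(self, value=None, timed_out=False):
        self._value = value  # tuple matching state_schema, or None if no state
        self._timed_out = timed_out
        self.timeout_ms = None  # records the last value passed to setTimeoutDuration

    @property
    def exists(self):
        return self._value is not None

    @property
    def get(self):
        if self._value is None:
            raise ValueError("State is not defined or has already been removed")
        return self._value

    @property
    def hasTimedOut(self):
        return self._timed_out

    def update(self, new_value):
        self._value = tuple(new_value)

    def remove(self):
        self._value = None

    def setTimeoutDuration(self, duration_ms):
        self.timeout_ms = duration_ms


def append_to_state(new_readings, state):
    # Hypothetical function using the same exists/get/update pattern as update_state
    if state.exists:
        (previous,) = state.get
        combined = previous + new_readings
    else:
        combined = list(new_readings)
    state.update((combined,))
    return len(combined)


state = FakeGroupState()
assert append_to_state([{"landed": 0}], state) == 1
assert append_to_state([{"landed": 1}], state) == 2
state.remove()
assert not state.exists
```

In the notebook, an %%ipytest cell could then call update_state directly, for example list(update_state((20,), iter([test_df]), FakeGroupState())), and assert on the state_status column of the returned data frame.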