pyspark如何实现树层次结构的深度遍历

pyspark如何实现树层次结构的深度遍历。在表中有两列,表示父子结构,如果用pyspark实现树结构,建立层次关系,并且实现某节点所有子孙节点的遍历访问。
比如:下面三行数据,表示节点id:0有两个子节点,一个孙节点。
sonID paraID
1 0
2 0
3 1


from pyspark.sql import SparkSession

# 创建 SparkSession
spark = SparkSession.builder \
    .appName("Tree Traversal") \
    .getOrCreate()

# 创建 DataFrame 存储父子关系
data = [(1, 0), (2, 0), (3, 1)]
columns = ['sonID', 'paraID']
df = spark.createDataFrame(data, columns)

# 创建递归函数来实现深度遍历
def traverse_tree(node_id):
    # 获取当前节点的所有子节点
    children = df.filter(df['paraID'] == node_id).select('sonID').rdd.flatMap(lambda x: x).collect()
    
    # 递归遍历子节点的子节点
    for child in children:
        print("Node:", child)
        traverse_tree(child)

# 设置根节点
root_node = 0

# 开始深度遍历
print("Root Node:", root_node)
traverse_tree(root_node)

你可以参考一下这个