Saturday, October 07, 2023

A PySpark idiom for efficient Model Inference

I recently needed to build an Apache Spark (PySpark) job where the task was (among other things) to use a Language Model (LM) to encode text into vectors. This is an embarrassingly parallel job, since each text maps to exactly one encoding, so something like Spark works very well here. We could, in theory at least, achieve an N-fold performance improvement by horizontally partitioning the data into N splits and encoding them using N parallel workers.

However, LMs (and Machine Learning (ML) models in general) usually take some time to initialize before they are ready for use. This initialization step loads the model's parameters (multi-dimensional tensors of weights learned during training) into memory. So it is not really feasible to do something like this:

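from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy

@dataclass
class Document:
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[numpy.ndarray] = None  # filled in by the encoder

# The RDD holds mutable Document objects (a pyspark.sql.Row is read-only).
def encode_row(row: Document) -> Document:
    # Anti-pattern: initialize_model() (your model-loading code) runs for EVERY row
    model = initialize_model()
    row.embedding = model.encode(row.content)
    return row

data_rdd = data_rdd.map(encode_row)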

This is because the model would be initialized once for every row in our RDD, which can be very time-consuming. We can address this by initializing it once on the driver and broadcasting it to all the workers, something I have done in the past.

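def encode_row(row: Document) -> Document:
    # Fetched from the broadcast and deserialized on the worker, not re-initialized
    model = bc_model.value
    row.embedding = model.encode(row.content)
    return row

model = initialize_model()      # runs once, on the driver
bc_model = sc.broadcast(model)  # the model must be picklable to be broadcast
data_rdd = data_rdd.map(encode_row)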

But Spark provides a higher-order function (HOF) specifically for this use case, called mapPartitions, which lets you create some heavyweight object(s) once per partition and then apply processing (using these heavyweight objects) to all rows in that partition. You could also broadcast the model from the driver instead of initializing it in each partition, which saves the initialization cost on the workers (they only need to deserialize it). Either way, you can think of initialize_model() as a wrapper for either approach. Using this idiom, our processing code would look like this:

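from typing import Iterable, Iterator

def encode_rows(rows: Iterable[Document]) -> Iterator[Document]:
    model = initialize_model()  # once per partition, not once per row
    for row in rows:
        row.embedding = model.encode(row.content)  # still one text at a time
        yield row

data_rdd = data_rdd.mapPartitions(encode_rows)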
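A back-of-the-envelope cost model shows why this matters. The notation is mine and is not taken from any benchmark: R rows spread over P partitions, a one-off cost T_init per model load, and a per-row encoding cost t. Summed over all workers and ignoring Spark's own scheduling overhead, the total work for the map version and the mapPartitions version is roughly:

```latex
\underbrace{R\,\bigl(T_{\text{init}} + t\bigr)}_{\texttt{map}\text{: one load per row}}
\qquad\text{vs.}\qquad
\underbrace{P\,T_{\text{init}} + R\,t}_{\texttt{mapPartitions}\text{: one load per partition}}
```

Since P is typically orders of magnitude smaller than R, the initialization term drops from R·T_init to P·T_init, while the encoding term R·t is unchanged.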

However, LMs (and ML models in general) are designed to process input in batches. Inference (at least for neural models) generally involves a lot of matrix multiplications, which the underlying tensor library parallelizes if you feed the model a batch (or a larger set) of inputs rather than one record at a time. If the model's encode method (or equivalent) uses a batch size B (usually the default value of its batch_size parameter), feeding it B or more records at a time can, in the best case, translate into a speedup of up to roughly B-fold over feeding it one record at a time. The model typically splits the input internally into batches of B records each, processing the batches sequentially and the records within each batch in parallel.
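To make the last point concrete, here is a minimal sketch of the kind of chunking an encode method typically performs internally. It is not the code of any particular library. `split_into_batches` is a hypothetical name, and its `batch_size` argument plays the role of B.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_batches(items: Sequence[T], batch_size: int) -> List[List[T]]:
    """Split items into consecutive batches of at most batch_size elements.

    The batches are processed one after another; the records inside one
    batch are what the tensor library multiplies in parallel.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


texts = [f"text {i}" for i in range(10)]
batches = split_into_batches(texts, batch_size=4)
print([len(b) for b in batches])  # [4, 4, 2]
```

With one text per call, every batch has size 1, so none of the within-batch parallelism gets used.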

So to allow the model to consume the rows in batches, we could change our code as follows.

from typing import Iterable, Iterator

def encode_rows(rows: Iterable[Document]) -> Iterator[Document]:
    model = initialize_model()
    docs = list(rows)  # materializes the whole partition in memory
    texts = [doc.content for doc in docs]
    embeddings = model.encode(texts)  # the model batches the texts internally
    for doc, embedding in zip(docs, embeddings):
        doc.embedding = embedding
        yield doc

data_rdd = data_rdd.mapPartitions(encode_rows)
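Because the control flow does not depend on Spark, you can check this idiom locally in plain Python. In this sketch, `FakeModel` and `fake_map_partitions` are hypothetical stand-ins for a real model and for `RDD.mapPartitions`, and the partitions are just in-memory lists. The point is to confirm that the model is initialized once per partition and that each partition reaches `encode` as a single batch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional


@dataclass
class Document:
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[List[float]] = None


class FakeModel:
    """Hypothetical stand-in for an LM that records how it is used."""

    init_count = 0               # how many times a model was "loaded"
    batch_sizes: List[int] = []  # size of every list passed to encode()

    def __init__(self) -> None:
        FakeModel.init_count += 1

    def encode(self, texts: List[str]) -> List[List[float]]:
        FakeModel.batch_sizes.append(len(texts))
        # Toy "embedding": the text length as a 1-dimensional vector.
        return [[float(len(text))] for text in texts]


def encode_rows(rows: Iterable[Document]) -> Iterator[Document]:
    model = FakeModel()  # plays the role of initialize_model()
    docs = list(rows)
    embeddings = model.encode([doc.content for doc in docs])
    for doc, embedding in zip(docs, embeddings):
        doc.embedding = embedding
        yield doc


def fake_map_partitions(
    partitions: List[List[Document]],
    func: Callable[[Iterator[Document]], Iterable[Document]],
) -> List[Document]:
    """Mimics RDD.mapPartitions: func is called once per partition with an iterator."""
    results: List[Document] = []
    for partition in partitions:
        results.extend(func(iter(partition)))
    return results


partitions = [
    [Document(f"doc {i}") for i in range(0, 5)],
    [Document(f"doc {i}") for i in range(5, 8)],
]
result = fake_map_partitions(partitions, encode_rows)
print(len(result), FakeModel.init_count, FakeModel.batch_sizes)  # 8 2 [5, 3]
```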

Obviously, the approach above assumes that you have enough memory per partition to hold the text for all the documents in the partition. If the texts in a partition are too large, you will get an Out of Memory (OOM) error and the job will abort. So, depending on your data and your architecture, the simplest (and probably somewhat brute-force) approach is to repartition your RDD into a larger number of (smaller) partitions, so that each partition's texts fit in memory. Something like this:

k = calculate_optimum_num_partitions()  # hypothetical helper: the number of partitions, computed dynamically or offline
data_rdd = data_rdd.repartition(k).mapPartitions(encode_rows)
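One simple way to compute k for the repartition call is a ceiling division. This sketch assumes you can estimate the total record count and have picked a target number of records per partition that fits in executor memory; you would tune both for your own data. `estimate_num_partitions` and `target_rows_per_partition` are hypothetical names.

```python
def estimate_num_partitions(total_rows: int, target_rows_per_partition: int) -> int:
    """Smallest k such that an even split has at most target_rows_per_partition rows each."""
    if total_rows < 0 or target_rows_per_partition <= 0:
        raise ValueError("need total_rows >= 0 and target_rows_per_partition > 0")
    # Integer ceiling division (avoids float rounding on very large counts).
    return max(1, -(-total_rows // target_rows_per_partition))


# e.g. 1,000,000 documents, where ~5,000 texts fit comfortably on one executor
k = estimate_num_partitions(1_000_000, 5_000)
print(k)  # 200
```

`repartition` spreads rows only roughly evenly, so leave some headroom in the target rather than sizing it right at the memory limit.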

But this can lead to many small partitions, which may add overhead for Spark, since it now has to schedule and coordinate more tasks. Also, if you are initializing the model inside the mapPartitions call, many small partitions mean many more model initializations. Another way (and basically the idiom I am trying to build up to in this blog post) is to leave the partitions intact and use itertools.islice to batch up the rows within each partition in code, instead of relying on the partition size as a side effect. Something like this:

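import itertools
from typing import Iterable, Iterator

batch_size = 32  # hypothetical value; tune it for your model and hardware

def encode_rows(rows: Iterable[Document]) -> Iterator[Document]:
    model = initialize_model()
    start = 0
    while True:
        end = start + batch_size
        # BUG: rows is an iterator, so this call loses records (see the EDIT below)
        batch = itertools.islice(rows, start, end)
        docs = list(batch)
        if len(docs) == 0:
            break
        texts = [doc.content for doc in docs]
        embeddings = model.encode(texts)
        start = end
        for doc, embedding in zip(docs, embeddings):
            doc.embedding = embedding
            yield doc

data_rdd = data_rdd.mapPartitions(encode_rows)

EDIT 2023-12-11: I found a problem with this approach that took me a while to solve, so I am sharing it here in case it helps someone down the line. When applying mapPartitions with the previous code block, the number of output records was often smaller than the number of input records, i.e., the job was losing records. I could mitigate it by repartitioning the RDD so that each partition contained fewer records than my batch size, so that itertools.islice was effectively called only once per partition. The cause is that mapPartitions passes in `rows` as a one-shot iterator, and islice advances that iterator: each call to `itertools.islice(rows, start, end)` skips `start` rows counted from wherever the previous batch stopped, not from the beginning of the partition, so those rows are silently dropped. (I had tested its behavior with integer elements, but the element type turns out not to matter: islice on a list starts from the beginning on every call, whereas the iterator that mapPartitions passes in does not.) My fix was to add `rows, rows_copy = itertools.tee(rows)` just before the islice call and only operate on `rows_copy` in the islice call. A simpler fix is to drop the offsets altogether and call `itertools.islice(rows, batch_size)` on each pass, since the iterator already remembers its position.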
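Here is a minimal reproduction of the problem that uses only the standard library, with integers standing in for rows. `batches_buggy` mirrors the start/end loop above, and `batches_fixed` shows the offset-free alternative; both names are illustrative. The tee workaround does work, because the untouched copy lets each islice call start again from the beginning of the partition. The cost is that tee keeps every row read so far in memory, and every batch re-scans the partition from the start. `batches_fixed` avoids both problems.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batches_buggy(rows: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Mirrors the start/end loop above; drops rows when rows is an iterator."""
    start = 0
    while True:
        end = start + batch_size
        # islice skips `start` items from the iterator's *current* position,
        # so from the second call on it throws away rows not yet processed.
        batch = list(itertools.islice(rows, start, end))
        if not batch:
            break
        yield batch
        start = end


def batches_fixed(rows: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Take the next batch_size items each time; the iterator tracks the position."""
    it = iter(rows)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            break
        yield batch


print(list(batches_buggy(iter(range(10)), 3)))  # [[0, 1, 2], [6, 7, 8]]
print(list(batches_buggy(list(range(10)), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batches_fixed(iter(range(10)), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```

The second line shows why a quick test on a list looks fine: a list hands islice a fresh iterator on every call. Inside encode_rows, the whole while loop then reduces to `for docs in batches_fixed(rows, batch_size):`, followed by the encode and yield steps.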

I am curious what people think of this approach. Using Spark to run ML inference at scale cannot be a new problem, but I wasn't able to find any information or best practices about it on the Internet. I did consider the possibility that my Google-fu may not be as strong as I think, so I also tried Bard, and it didn't give me much to go on either. I am sure many Data Engineers before me have looked at this problem and have their own favorite solutions. Please share in the comments if you can!