Scikit-Learn - Profiling transformers with line_profiler

Posted on Apr 23, 2022

It’s important to be able to profile transformers. I usually use line_profiler and a Jupyter Notebook.

Take this ‘PlusOne’ transformer as an example:

%load_ext line_profiler

import copy
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

class PlusOne(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        X = copy.copy(X)

        if isinstance(X, pd.DataFrame):
            X[self.columns] += 1
        elif isinstance(X, dict):
            for column in self.columns:
                X[column] += 1

        return X

plus_one = PlusOne(columns=["a"])
row = {"a": 1}

%%timeit
plus_one.transform(row)

666 ns ± 8.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%%timeit reports about 666 ns per call, which is already very fast. line_profiler shows where that time goes:

%lprun -f plus_one.transform plus_one.transform(row)
Timer unit: 1e-06 s

Total time: 9e-06 s
File: /tmp/ipykernel_3915658/
Function: transform at line 13

Line #      Hits         Time  Per Hit   % Time  Line Contents
    13                                               def transform(self, X: pd.DataFrame) -> pd.DataFrame:
    14         1          5.0      5.0     55.6          X = copy.copy(X)
    16         1          1.0      1.0     11.1          if isinstance(X, pd.DataFrame):
    17                                                       X[self.columns] += 1
    18         1          0.0      0.0      0.0          elif isinstance(X, dict):
    19         2          1.0      0.5     11.1              for column in self.columns:
    20         1          1.0      1.0     11.1                  X[column] += 1
    22         1          1.0      1.0     11.1          return X

The copy call accounts for about 56% of the time. The shallow copy is there to avoid overwriting the original input. In most cases, especially in production, you won’t actually need it, or you can shallow copy the input yourself beforehand.
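To make the role of the shallow copy concrete, here is a minimal sketch using plain functions instead of the transformer. The helper names `plus_one_in_place` and `plus_one_copied` are made up for this illustration and are not part of the original code.

```python
import copy


def plus_one_in_place(row, columns):
    # Hypothetical helper: works directly on the caller's dict.
    for column in columns:
        row[column] += 1
    return row


def plus_one_copied(row, columns):
    # Same logic, but on a shallow copy so the caller's dict is left alone.
    row = copy.copy(row)
    for column in columns:
        row[column] += 1
    return row


original = {"a": 1}
result = plus_one_copied(original, ["a"])
# original still holds the old value, result holds the incremented one

mutated_input = {"a": 1}
same_object = plus_one_in_place(mutated_input, ["a"])
# mutated_input itself was incremented and same_object is that very dict
```

A shallow copy only protects the top level of the dict. Nested mutable values, such as a list stored under a key, are still shared between the copy and the original.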

The dict type also has its own copy method, which is faster than the generic copy.copy implementation. In addition, since in production we always pass dictionaries, the dict check is only reached after the DataFrame check has failed, so every call pays for two isinstance comparisons. Checking for dict first might give more throughput:

class PlusOne(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        if isinstance(X, dict):
            X = X.copy()
            for column in self.columns:
                X[column] += 1
        if isinstance(X, pd.DataFrame):
            X = copy.copy(X)
            X[self.columns] += 1
        return X

430 ns ± 9.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
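The difference between copy.copy and dict.copy can also be checked in isolation, outside of the transformer. The following is a rough sketch using the standard timeit module; the constant `NUMBER` and the variable names are choices made for this example, and the absolute numbers will vary by machine and Python version.

```python
import copy
import timeit

NUMBER = 200_000  # calls per measurement, chosen arbitrarily for this sketch
row = {"a": 1}

# timeit.timeit returns the total time in seconds for NUMBER calls.
# Both measurements include the same lambda call overhead.
generic_total = timeit.timeit(lambda: copy.copy(row), number=NUMBER)
builtin_total = timeit.timeit(lambda: row.copy(), number=NUMBER)

# Convert the totals to nanoseconds per call.
generic_ns = generic_total / NUMBER * 1e9
builtin_ns = builtin_total / NUMBER * 1e9

print(f"copy.copy(row): {generic_ns:.0f} ns per call")
print(f"row.copy():     {builtin_ns:.0f} ns per call")
```

Both calls produce an equal shallow copy of the dict, so only the cost of getting there differs.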

The rewritten transform is roughly 35% faster per call (430 ns versus 666 ns). Its line_profiler output:

Timer unit: 1e-06 s

Total time: 1e-05 s
File: /tmp/ipykernel_3915658/
Function: transform at line 8

Line #      Hits         Time  Per Hit   % Time  Line Contents
     8                                               def transform(self, X: pd.DataFrame) -> pd.DataFrame:
     9         1          2.0      2.0     20.0          if isinstance(X, dict):
    10         1          2.0      2.0     20.0              X = X.copy()
    11         2          2.0      1.0     20.0              for column in self.columns:
    12         1          1.0      1.0     10.0                  X[column] += 1
    13         1          3.0      3.0     30.0          if isinstance(X, pd.DataFrame):
    14                                                       X = copy.copy(X)
    15                                                       X[self.columns] += 1
    16         1          0.0      0.0      0.0          return X

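Note that the second isinstance check still runs for dictionaries and accounts for 30% of the profiled time. We can improve performance further by dropping the copy or by adding a separate function that handles only dictionaries, which might cut the time by another 40%.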
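One possible shape for such a dictionary-only function is sketched below. The name `plus_one_dict` and the `copy_input` flag are hypothetical and do not appear in the original transformer, and the 40% figure above is an estimate that this sketch does not measure.

```python
def plus_one_dict(row, columns, copy_input=True):
    """Hypothetical dict-only variant of PlusOne.transform."""
    # No isinstance checks: the caller guarantees that `row` is a dict.
    if copy_input:
        row = row.copy()
    for column in columns:
        row[column] += 1
    return row


row = {"a": 1}

# Default: the input is copied first, so `row` is left untouched.
safe = plus_one_dict(row, ["a"])

# Opt out of the copy: `row` is modified in place and returned as-is.
fast = plus_one_dict(row, ["a"], copy_input=False)
```

Skipping the copy is only safe when the caller no longer needs the original values, which is the trade-off described earlier for the shallow copy.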

However, in most applications the model itself costs far more than the rest of the pipeline, so savings like these rarely dominate. If the pipeline does have throughput issues, line_profiler is a quick and easy way to find them.