Scikit-Learn - Custom Transformers

Posted on Apr 4, 2022

Scikit-learn provides plenty of transformers, such as StandardScaler, but sometimes you want to implement your own, for example to select columns or to add specific values.

Below is a sample transformer that adds 1 to the specified columns:

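import copy

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class PlusOne(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        # Store parameters unchanged; BaseEstimator reads them back in get_params().
        self.columns = columns

    def fit(self, X, y=None):
        # Nothing to learn here, so fit only returns self.
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        # Adding a scalar returns a new object; X itself is not modified.
        return X[self.columns] + 1


sample = pd.DataFrame([{"a": 1}])
# Pass a list so the selection stays a DataFrame rather than a Series.
PlusOne(columns=["a"]).fit_transform(sample)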

Every custom estimator should inherit from BaseEstimator, and if it is a transformer it should also inherit from TransformerMixin.

Use the class initialization to pass in parameters, such as self.columns, which holds the columns the transformer operates on. Anything learned from the data can then be stored on the instance in the fit function.
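
As a rough illustration of what the two base classes contribute: BaseEstimator supplies get_params and set_params by inspecting the __init__ arguments, and TransformerMixin supplies fit_transform. The sketch below redefines PlusOne so that it runs on its own; the variable names are only for illustration.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone


class PlusOne(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.columns] + 1


plus_one = PlusOne(columns=["a"])

# get_params() comes from BaseEstimator and reads the __init__ arguments.
params = plus_one.get_params()

# clone() builds a fresh, unfitted estimator from those same parameters.
plus_one_clone = clone(plus_one)

# fit_transform() comes from TransformerMixin: it calls fit, then transform.
result = plus_one.fit_transform(pd.DataFrame({"a": [1, 2]}))
```

This only works because __init__ stores each argument under an attribute of the same name, which is why the constructor should do nothing beyond assigning its parameters.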

A sample input of:

a
0 1

Will be transformed to:

a
0 2

fit_transform calls the fit function first and then the transform function, so anything stored during fit is available in transform. For example, we can use the fit function to remember the mean of each column, and then add it to the original dataframe:

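class PlusMean(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        # Learned state is created in fit; the trailing underscore follows
        # scikit-learn's naming convention for fitted attributes.
        self.mean_ = X[self.columns].mean()

        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        # Work on a copy so the caller's DataFrame is left untouched.
        X = copy.copy(X)

        return X[self.columns] + self.mean_


sample = pd.DataFrame([{"a": 1}, {"a": 3}])
# Pass a list so the selection stays a DataFrame rather than a Series.
PlusMean(columns=["a"]).fit_transform(sample)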

A sample input of:

a
0 1
1 3

Will be transformed to:

a
0 3
1 5
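
The value remembered in fit is reused whenever transform is called later, including on data the transformer has never seen. A minimal sketch of that behaviour follows, assuming the learned mean is stored as mean_ (scikit-learn's trailing-underscore convention) and using a hypothetical new_data frame:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class PlusMean(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        # Remember the per-column mean of the data seen during fit.
        self.mean_ = X[self.columns].mean()
        return self

    def transform(self, X):
        return X[self.columns] + self.mean_


train = pd.DataFrame({"a": [1, 3]})   # the mean of "a" is 2
new_data = pd.DataFrame({"a": [10]})  # hypothetical unseen rows

plus_mean = PlusMean(columns=["a"]).fit(train)

# transform does not recompute the mean; it adds the stored training mean.
shifted = plus_mean.transform(new_data)
```

Here shifted holds 12 for column a, because the mean of the training data (2) is added rather than the mean of new_data.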

Custom transformers let us package this logic as steps inside a Pipeline, so the same preprocessing is applied at training time and at inference time, which reduces inference issues.
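
As a sketch of how this might look, the example below chains PlusOne with StandardScaler in a Pipeline. The step names ("plus_one", "scale") and the sample frames are made up for illustration, and PlusOne is redefined so the block runs on its own.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


class PlusOne(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.columns] + 1


pipeline = Pipeline([
    ("plus_one", PlusOne(columns=["a"])),  # keeps column "a" and adds 1
    ("scale", StandardScaler()),           # standardises the result
])

train = pd.DataFrame({"a": [1, 3], "b": [7, 8]})
new_data = pd.DataFrame({"a": [5], "b": [0]})

# Fit every step on the training data and transform it in one call.
train_scaled = pipeline.fit_transform(train)

# At inference time the fitted steps are reused in the same order.
new_scaled = pipeline.transform(new_data)
```

Because the column selection and the +1 live inside the pipeline, there is no separate preprocessing code to keep in sync between training and serving.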