Scikit-Learn - Pipeline - Workarounds using transform when model is not the last step.

Posted on May 30, 2022

One issue with Scikit-learn’s Pipeline in production is that it requires an estimator at the final stage. Often, however, you’ll need extra logic after the model, for example normalizing the output values or performing some arithmetic on them.

The body of Pipeline.predict in scikit-learn looks like this:

def predict(self, X, **predict_params):
    Xt = X
    for _, name, transform in self._iter(with_final=False):
        Xt = transform.transform(Xt)
    return self.steps[-1][1].predict(Xt, **predict_params)

To use this method, the final step must be a model with a predict method: every earlier step is applied with transform, and only the last step’s predict is called, through self.steps[-1][1].predict(Xt, **predict_params). Nothing can run after it.
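A quick way to check this constraint, sketched with scikit-learn’s built-in MinMaxScaler and LogisticRegression rather than the custom classes used later: the pipeline only exposes predict when its last step has one, so hasattr returns False for a pipeline that ends in a transformer.

```python
from sklearn import linear_model, pipeline, preprocessing

# Last step is a model: the pipeline exposes predict.
with_model = pipeline.make_pipeline(
    preprocessing.MinMaxScaler(),
    linear_model.LogisticRegression(),
)

# Last step is a transformer: predict is not available on the pipeline.
transformer_last = pipeline.make_pipeline(
    preprocessing.MinMaxScaler(),
)

print(hasattr(with_model, "predict"))        # True
print(hasattr(transformer_last, "predict"))  # False
```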

There are two workarounds: wrap the model in a transformer step, or modify the Pipeline class itself. This post covers wrapping the model.

An example workflow without using Pipelines

from __future__ import annotations

import pandas as pd
import numpy as np
import copy
from typing import Union, List, Dict

from sklearn import datasets, linear_model, model_selection, preprocessing, metrics, base, pipeline, utils

iris: utils.Bunch = datasets.load_iris(as_frame=True)

df: pd.DataFrame = iris["data"]
target: pd.Series = iris["target"]

X_train: pd.DataFrame
X_test: pd.DataFrame
y_train: pd.Series
y_test: pd.Series
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    df, target
)

scaler: preprocessing.MinMaxScaler = preprocessing.MinMaxScaler()
X_train['sepal length (cm)'] = scaler.fit_transform(X_train['sepal length (cm)'].to_numpy().reshape(-1, 1))
X_test['sepal length (cm)'] = scaler.transform(X_test['sepal length (cm)'].to_numpy().reshape(-1, 1))

model: linear_model.LogisticRegression = linear_model.LogisticRegression().fit(X_train, y_train)
metrics.accuracy_score(y_test, model.predict(X_test))

To accept dictionaries (for example a single record from a request) at prediction time, we need custom transformers that wrap preprocessing.MinMaxScaler and linear_model.LogisticRegression, plus a final step that converts the model’s NumPy output back into a DataFrame or a dictionary.
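Some context on why the dictionary case needs its own handling: scikit-learn estimators expect 2D input of shape (n_samples, n_features), so one record stored as a dict has to become a single-row table before predict. A minimal sketch, assuming the incoming record is simulated with X.iloc[0].to_dict() and that max_iter=1000 is an arbitrary setting to avoid convergence warnings. The EstimatorWrapper below takes a different route and builds a NumPy row from X.values(), which relies on the dict keys already being in training column order.

```python
import pandas as pd
from sklearn import datasets, linear_model

iris = datasets.load_iris(as_frame=True)
X: pd.DataFrame = iris["data"]
y: pd.Series = iris["target"]

model = linear_model.LogisticRegression(max_iter=1000).fit(X, y)

# A single record as it might arrive in a request: {column name: value}.
record = X.iloc[0].to_dict()

# Wrap the record in a list to get a one-row DataFrame. Passing
# columns=X.columns keeps the training column order whatever the key order.
row = pd.DataFrame([record], columns=X.columns)

prediction = model.predict(row)
print(row.shape)         # (1, 4)
print(prediction.shape)  # (1,)
```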

class Scaler(base.BaseEstimator, base.TransformerMixin):
    def __init__(self, columns: List[str]):
        self.columns: List[str] = columns

    def fit(self, X: pd.DataFrame, y=None) -> Scaler:
        # Create the per-column scalers here rather than in __init__, since
        # scikit-learn expects __init__ to only store parameters.
        self.scalers: Dict[str, preprocessing.MinMaxScaler] = {column: preprocessing.MinMaxScaler() for column in self.columns}
        column: str
        for column in self.columns:
            self.scalers[column].fit(X[column].to_numpy().reshape(-1, 1))

        return self

    def transform(self, X: Union[pd.DataFrame, Dict[str, Union[float, int]]]) -> Union[pd.DataFrame, Dict[str, Union[float, int]]]:
        X = copy.copy(X)

        if isinstance(X, dict):
            for column in self.columns:
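                # Unwrap the (1, 1) result so the dict keeps plain floats;
                # a nested array would break np.array(list(X.values())) later.
                X[column] = float(self.scalers[column].transform(np.array(X[column]).reshape(1, -1))[0, 0])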
        elif isinstance(X, pd.DataFrame):
            for column in self.columns:
                X[column] = self.scalers[column].transform(X[column].to_numpy().reshape(-1, 1))

        return X

class EstimatorWrapper(base.BaseEstimator, base.TransformerMixin):
    def __init__(self, model: linear_model.LogisticRegression):
        self.model: linear_model.LogisticRegression = model

    def fit(self, X, y=None) -> EstimatorWrapper:
        return self

    def predict(self, X: Union[pd.DataFrame, Dict[str, Union[float, int]]]) -> np.ndarray:
        return self.transform(X)

    def transform(self, X: Union[pd.DataFrame, Dict[str, Union[float, int]]]) -> np.ndarray:
        X = copy.copy(X)

        if isinstance(X, dict):
            # Assumes the dict keys are in the same order as the training columns.
            return self.model.predict(np.array(list(X.values())).reshape(1, -1))
        elif isinstance(X, pd.DataFrame):
            return self.model.predict(X)

class SetOutput(base.BaseEstimator, base.TransformerMixin):
    def __init__(self, name: str):
        self.name: str = name

    def fit(self, X: pd.DataFrame, y=None) -> SetOutput:
        return self

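    def transform(self, X: np.ndarray) -> Union[pd.DataFrame, Dict[str, Union[float, int]]]:
        # Several predictions: return them as a one-column DataFrame.
        if X.shape[0] > 1:
            return pd.DataFrame({self.name: X})

        # A single prediction: return a plain dict with a built-in float.
        elif X.shape[0] == 1:
            return {self.name: float(X[0])}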

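A note on SetOutput: cast the NumPy value to a built-in Python type (here with float(...)) before returning it. Otherwise the output is a NumPy scalar such as numpy.int64, which the standard json module cannot serialize when the prediction is returned from an API.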
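A small illustration of why the cast matters, assuming the prediction ends up in a JSON response. NumPy scalars pickle without problems, but json.dumps rejects NumPy integer types. The labels array below is made-up sample data standing in for predict output:

```python
import json

import numpy as np

labels = np.array([2, 0, 1])  # stand-in for class labels returned by predict
first = labels[0]
print(type(first))  # a NumPy integer type such as numpy.int64, not a Python int

serialization_failed = False
try:
    json.dumps({"prediction": first})
except TypeError as err:
    serialization_failed = True
    print("json.dumps failed:", err)

# Casting to a built-in type makes the output JSON-serializable.
payload = json.dumps({"prediction": float(first)})
print(payload)  # {"prediction": 2.0}
```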

We can then build a pipeline for the preprocessing transformers. There is only one here, but with several transformers we could fit all of them before moving on to the model.

pipe: pipeline.Pipeline = pipeline.make_pipeline(
    Scaler(columns=['sepal length (cm)']),
)
X_train = pipe.fit_transform(X_train, y_train)

You could also fit the transformers together with the model, with or without GridSearchCV. I prefer to build a pipeline with only the transformers, transform the whole dataset once, and then tune the model on the result: if the transformers sit inside the pipeline passed to GridSearchCV, they are refit for every fold and parameter combination, which slows down hyperparameter tuning.
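A minimal sketch of that workflow, assuming a plain MinMaxScaler over all features as the preprocessing and an illustrative grid over LogisticRegression’s C. The grid values, cv=5, random_state=0 and max_iter=1000 are arbitrary choices, not taken from this post. The transformer pipeline is fit once, and GridSearchCV only refits the model:

```python
from sklearn import datasets, linear_model, model_selection, pipeline, preprocessing

iris = datasets.load_iris(as_frame=True)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    iris["data"], iris["target"], random_state=0
)

# 1. Fit the preprocessing once and transform the training data up front.
preprocess = pipeline.make_pipeline(preprocessing.MinMaxScaler())
X_train_t = preprocess.fit_transform(X_train, y_train)

# 2. Tune only the model; the scaler is not refit for each fold or candidate.
search = model_selection.GridSearchCV(
    linear_model.LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    cv=5,
)
search.fit(X_train_t, y_train)

model = search.best_estimator_
test_accuracy = model.score(preprocess.transform(X_test), y_test)
print(search.best_params_, test_accuracy)
```

The trade-off is that the scaler’s statistics come from the whole training split, so every validation fold has already influenced the preprocessing; keeping the transformers inside the searched pipeline avoids that at the cost of refitting them each time.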

Then we can create and fit a model, and wrap it in our EstimatorWrapper:

model = linear_model.LogisticRegression().fit(X_train, y_train)
wrapped_estimator = EstimatorWrapper(model)

Once any model tuning is done, we can create another pipeline that stacks:
1. Preprocessing
2. Model
3. Post-processing

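model_with_pipe = pipeline.make_pipeline(
    pipe,                          # 1. preprocessing (already fitted)
    wrapped_estimator,             # 2. model, exposed through transform
    SetOutput(name='prediction'),  # 3. post-processing; 'prediction' is an illustrative name
)
model_with_pipe.predict = model_with_pipe.transform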

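Because every step, including the wrapped model, implements transform, calling transform on the pipeline runs preprocessing, prediction and post-processing in order. That makes it safe to override the pipeline’s predict with its transform and call it on a DataFrame:

model_with_pipe.predict(X_test)

or with a dictionary:

model_with_pipe.predict(X_test.iloc[0].to_dict())
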
If we used the model without the wrapper, it would only work as the last step:

model_with_pipe = pipeline.make_pipeline(
    pipe,
    model,
)

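And if we added any post-processing step after it, prediction would fail, because Pipeline.predict looks for predict on the last step, which is now SetOutput:

model_with_pipe = pipeline.make_pipeline(
    pipe,
    model,
    SetOutput(name='prediction'),  # 'prediction' is an illustrative output name
)
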
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_2821917/ in <cell line: 1>()
----> 1 model_with_pipe.predict(X_test)

~/.pyenv/versions/3.10.4/envs/3_10_4/lib/python3.10/site-packages/sklearn/utils/ in __get__(self, obj, owner)
    125             # delegate only on instances, not the classes.
    126             # this is to allow access to the docstrings.
--> 127             if not self.check(obj):
    128                 raise attr_err
    129             out = MethodType(self.fn, obj)

~/.pyenv/versions/3.10.4/envs/3_10_4/lib/python3.10/site-packages/sklearn/ in check(self)
     44     def check(self):
     45         # raise original `AttributeError` if `attr` does not exist
---> 46         getattr(self._final_estimator, attr)
     47         return True

AttributeError: 'SetOutput' object has no attribute 'predict'