27 Jan 2022

Scikit-learn fit() and transform() methods

  • The transform method applies a transformation to a dataset, using the parameters that fit has already learned. It does not learn anything new from the data it is given.
  • The fit method is used to learn the parameters or statistics (means, variances, scaling factors, etc.) of a data transformation based on the input data. The specific behavior and purpose of fit may vary depending on the estimator or transformer being used.
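The split between the two methods can be sketched with a toy transformer. MeanCenterer below is a hypothetical class written only for illustration; it is not part of scikit-learn, and it only mimics the fit/transform convention under the assumption that the sole learned statistic is a per-column mean.

```python
import numpy as np


class MeanCenterer:
    """Hypothetical toy transformer (not a scikit-learn class)."""

    def fit(self, X):
        # Learn one statistic per column from the data passed to fit
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        return self  # returning self allows fit(...).transform(...) chaining

    def transform(self, X):
        # Apply the stored statistic; nothing is re-learned here
        return np.asarray(X, dtype=float) - self.mean_


# Small made-up arrays, used only for this illustration
X_train = np.array([[1.0, 10.0], [3.0, 30.0]])
X_test = np.array([[2.0, 20.0], [4.0, 40.0]])

centerer = MeanCenterer().fit(X_train)
X_train_centered = centerer.transform(X_train)
X_test_centered = centerer.transform(X_test)  # uses the training means
```

Note that the test data is centered with the means learned from the training data, not with its own means.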

Example 1: Feature Scaling using StandardScaler

from sklearn.preprocessing import StandardScaler

# Create a StandardScaler object
scaler = StandardScaler()

# Fit the scaler on the training data
scaler.fit(X_train)

# Apply the learned scaling to transform the training data
X_train_scaled = scaler.transform(X_train)

# Apply the same scaling to transform the test data
X_test_scaled = scaler.transform(X_test)

In this example, fit is used to learn the mean and standard deviation from the training data (X_train). Then, the transform method is used to apply the same scaling to both the training data and the test data (X_test).
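A self-contained version of this example might look like the sketch below. The arrays are made-up values chosen for illustration, since the original snippet does not define X_train or X_test. After fitting, the learned statistics can be inspected through the mean_ and scale_ attributes.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up single-feature data, assumed only for this illustration
X_train = np.array([[1.0], [2.0], [3.0]])
X_test = np.array([[4.0]])

scaler = StandardScaler()
scaler.fit(X_train)

# Statistics learned from the training data only
print(scaler.mean_)   # per-feature mean
print(scaler.scale_)  # per-feature standard deviation (population, ddof=0)

# Both sets are scaled with the training statistics
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
```

Because the scaler never sees X_test during fit, the scaled test values are not guaranteed to have zero mean or unit variance.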

Example 2: Dimensionality Reduction using PCA

from sklearn.decomposition import PCA

# Create a PCA object
pca = PCA(n_components=2)

# Fit the PCA model on the training data and transform it
X_train_pca = pca.fit_transform(X_train)

# Apply the learned PCA transformation to the test data
X_test_pca = pca.transform(X_test)

In this example, fit_transform fits the PCA model on the training data (X_train) and projects that data in a single call. Fitting learns the per-feature mean and the two principal components. The transform method then applies the same centering and projection to the test data (X_test) without refitting, so both sets end up in the same two-dimensional space.
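As a runnable sketch of the PCA example, the block below uses randomly generated data (an assumption made purely for illustration; the original does not define X_train or X_test). It also checks that, with the default settings, transform amounts to subtracting the learned mean and projecting onto the learned components.

```python
import numpy as np
from sklearn.decomposition import PCA

# Random data assumed for illustration: 4 features reduced to 2
rng = np.random.default_rng(0)
X_train = rng.normal(size=(20, 4))
X_test = rng.normal(size=(5, 4))

pca = PCA(n_components=2)

# Learn the mean and components from X_train, then project X_train
X_train_pca = pca.fit_transform(X_train)

# Reuse the learned mean and components for X_test
X_test_pca = pca.transform(X_test)

# Manual projection with the fitted attributes (default whiten=False)
X_test_manual = (X_test - pca.mean_) @ pca.components_.T

print(pca.components_.shape)                # (2, 4)
print(X_train_pca.shape, X_test_pca.shape)  # (20, 2) (5, 2)
```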

Example 3: Learning Class Labels with LabelEncoder

from sklearn.preprocessing import LabelEncoder

# Create a LabelEncoder object
encoder = LabelEncoder()

# Fit the encoder on the target variable to learn the class labels
encoder.fit(y)

In this example, the fit method learns the class labels from the target variable (y). The encoder collects the unique labels, sorts them, and stores them in its classes_ attribute; each class is then identified by its position in that sorted array. Calling fit does not change y itself. The actual conversion to integers happens when transform is called.
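The sketch below illustrates this with a small made-up label list (the original does not define y). It shows the learned classes, the integer encoding produced by transform, and the round trip back through inverse_transform.

```python
from sklearn.preprocessing import LabelEncoder

# Made-up labels, assumed only for this illustration
y = ["cat", "dog", "cat", "bird"]

encoder = LabelEncoder()
encoder.fit(y)

# Unique labels in sorted order; the index is the integer code
classes = list(encoder.classes_)   # ['bird', 'cat', 'dog']

# Replace each label with its index in classes_
y_encoded = encoder.transform(y)   # [1, 2, 1, 0]

# Map the integer codes back to the original labels
y_decoded = list(encoder.inverse_transform(y_encoded))
```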