How to train a linear classification model using scikit-learn

A linear classification model is a simple but powerful tool in data science for assigning data points to distinct classes based on their features. These models separate classes with linear decision boundaries: straight lines in 2D, planes in 3D, and hyperplanes in higher dimensions. This post will guide you through the steps to build one using scikit-learn. You can also follow along with this Colab notebook.

We'll be using the Iris plants dataset, which contains 3 classes with 50 instances each. Each class refers to a type of iris plant. One class is linearly separable from the other two, and those two are not linearly separable from one another. There are four input variables: sepal length in cm, sepal width in cm, petal length in cm, and petal width in cm.

Setting Up and Loading Data

To begin, we must first import the necessary libraries and load our dataset:

from sklearn.datasets import load_iris

iris = load_iris()
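Before going further, it can help to confirm what `load_iris` returns. This is a small sketch that only uses attributes of the loaded dataset object (`data`, `target`, `feature_names`, `target_names`); the printing is just for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# iris.data holds the measurements; iris.target holds class ids 0, 1, 2
print(iris.data.shape)           # (150, 4)
print(iris.feature_names)
print(iris.target_names)         # ['setosa' 'versicolor' 'virginica']

# Count instances per class: the dataset is balanced at 50 per class
print(np.bincount(iris.target))  # [50 50 50]
```

The class ids in `iris.target` index into `iris.target_names`, which is why later plots and tables can label classes by name.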

Visualizing the Data with a 2D Chart

Often, visualizing data can offer insights that raw numbers or descriptions fail to provide. For our Iris dataset, let's plot a 2D chart of the first two features, i.e., sepal length and sepal width, and color each data point based on its class. This visualization will give us an idea of how these two features can differentiate between the Iris classes.

import matplotlib.pyplot as plt

X = iris.data[:, :2]  # sepal length and sepal width
y = iris.target

colors = ["red", "green", "blue"]
plt.figure(figsize=(10, 6))
for i, (color, target_name) in enumerate(zip(colors, iris.target_names)):
    # Plot only the points that belong to class i
    plt.scatter(X[y == i, 0], X[y == i, 1], color=color, label=target_name, marker="x")
plt.title('2D Chart of Sepal Length vs. Sepal Width')
plt.xlabel('Sepal Length (cm)')
plt.ylabel('Sepal Width (cm)')
plt.legend(loc="upper right")
plt.show()

2D Chart of Sepal Length (cm) vs. Sepal Width (cm)

This scatter plot shows how sepal length and sepal width relate to the three Iris classes. One class stands clearly apart, while the other two overlap noticeably, which is consistent with their not being linearly separable from one another.

Splitting Data and Standardizing Numerical Feature Values

First, we need to split our data into a dataset of training examples and a dataset of test examples. This split will later help us evaluate the performance of our model on unseen examples that we didn't use to train the model.

For the purposes of this tutorial, we're going to just use the sepal length and sepal width features:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, 0:2], iris.target, test_size=0.2, random_state=42
)
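As a quick sanity check, here is a sketch that repeats the same split on its own. `test_size=0.2` on Iris's 150 rows leaves 120 training and 30 test examples; those 30 test examples are what the accuracy and confusion matrix later in the post are computed on.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Same call as above: keep only the two sepal features
X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, 0:2], iris.target, test_size=0.2, random_state=42
)

# 20% of the 150 rows go to the test set, the remaining 120 to training
print(X_train.shape, X_test.shape)  # (120, 2) (30, 2)
```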

Data standardization is essential for algorithms that are sensitive to the scale of the features. We'll standardize each feature by subtracting its mean and dividing by its standard deviation, so that every feature ends up with zero mean and unit variance:

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

There are two things worth noting here:

  1. Fitting the transformation: Always fit the transformation using only the training data. This mimics real-world use, where new, unseen data cannot influence the scaling. When predicting on fresh data, apply the same transformation using the statistics learned from the training data (scaler.transform, not fit_transform).
  2. Standardization benefits: Standardizing the input features puts them all on a similar scale, which can help both training and interpretability. Without scaling, features with larger magnitudes can disproportionately influence the model, particularly with gradient-based methods such as SGD. Standardization also makes coefficient magnitudes more directly comparable as indicators of each feature's importance. Without it, the coefficients would partly reflect the units and ranges of the features, making it difficult to compare their significance accurately.
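To see what StandardScaler computes, the sketch below standardizes a small hypothetical dataset by hand, with the mean and standard deviation taken from the training rows only, and checks that it matches StandardScaler. The toy arrays are made up for illustration; StandardScaler uses the population standard deviation (ddof=0), which the manual version mirrors.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data: two features on very different scales
X_train = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]])
X_test = np.array([[2.5, 250.0], [5.0, 500.0]])

# Statistics come from the training data only
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)  # population std (ddof=0)

X_train_manual = (X_train - mean) / std
X_test_manual = (X_test - mean) / std  # reuse the training statistics

scaler = StandardScaler().fit(X_train)
print(np.allclose(scaler.transform(X_train), X_train_manual))  # True
print(np.allclose(scaler.transform(X_test), X_test_manual))    # True
```

After scaling, both features contribute on the same footing even though the second one was originally 100 times larger.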

Training The Linear Classifier via SGD

With our data prepped and ready to go, let's create and train the linear classifier:

from sklearn.linear_model import SGDClassifier

clf = SGDClassifier(random_state=42)
clf.fit(X_train, y_train)

We are using Stochastic Gradient Descent (SGD) to train our model. In simple terms, SGD updates the model's linear coefficients iteratively: for one training example at a time, it computes the gradient of the loss (the prediction error) and moves the coefficients a small step in the opposite direction, which reduces the loss. By default, SGDClassifier uses the hinge loss, which makes it a linear support vector machine, but there are a variety of other losses and configuration options. For a deep dive into the available parameters, refer to the SGDClassifier documentation.
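To make the update rule concrete, here is a minimal from-scratch sketch of SGD with hinge loss for a binary problem with labels in {-1, +1}. It is a simplified illustration, not scikit-learn's implementation: the constant learning rate `lr`, the L2 strength `alpha`, the epoch count, the function name `sgd_hinge`, and the toy data are all assumptions. (SGDClassifier itself uses a learning-rate schedule and handles three classes with one-vs-rest.)

```python
import numpy as np

def sgd_hinge(X, y, lr=0.01, alpha=0.0001, epochs=20, seed=0):
    """Hypothetical helper: binary linear classifier, hinge loss, plain SGD.

    X: (n_samples, n_features) array; y: labels in {-1, +1}.
    Returns weights w and bias b such that sign(X @ w + b) predicts y.
    """
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    for _ in range(epochs):
        # Visit the examples in a fresh random order each epoch
        for i in rng.permutation(n_samples):
            margin = y[i] * (X[i] @ w + b)
            # Subgradient of alpha/2 * ||w||^2 + max(0, 1 - margin)
            if margin < 1:
                grad_w = alpha * w - y[i] * X[i]
                grad_b = -y[i]
            else:
                grad_w = alpha * w
                grad_b = 0.0
            # Step against the gradient to reduce the loss
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

# Hypothetical toy data: two well-separated clusters
X = np.array([[-2.0, -1.0], [-1.5, -2.0], [-1.0, -1.5],
              [1.0, 1.5], [1.5, 2.0], [2.0, 1.0]])
y = np.array([-1, -1, -1, 1, 1, 1])

w, b = sgd_hinge(X, y)
pred = np.sign(X @ w + b)
print(pred)  # expected to match y
```

Only examples inside the margin (margin < 1) push the weights; correctly classified examples with a comfortable margin contribute just the small regularization shrinkage.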

Understanding Feature Contributions and Visualizing Decision Boundaries

Although most real-world problems are not linear, fitting a linear classification model is a great first step because it is simple and easy to interpret. By examining the model's coefficients, we can better understand the weight, or importance, of each feature for each class:

import numpy as np
import pandas as pd

class_lines = pd.DataFrame(
    np.append(clf.coef_.T, [clf.intercept_], axis=0),  # rows: features, then intercept
    index=iris.feature_names[:2] + ["intercept"],
    columns=iris.target_names,
)
print(class_lines)

#                       setosa  versicolor  virginica
# sepal length (cm) -24.464447   -1.488340   6.063765
# sepal width (cm)   12.907980   -3.411526  -3.036891
# intercept         -11.847741   -4.899218  -1.279970

You'll notice that there is one column of coefficients, and therefore one decision line, for each class. This is because SGDClassifier handles multi-class problems with a one-vs-rest scheme: it learns a separate linear decision boundary for each class versus all the remaining classes. We can see exactly how this works by visualizing the decision boundary for each class, once again using sepal length and width as our features:

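x_min, x_max = X_train[:, 0].min() - 0.1, X_train[:, 0].max() + 0.1
y_min, y_max = X_train[:, 1].min() - 0.1, X_train[:, 1].max() + 0.1
xx = np.linspace(x_min, x_max)

plt.figure(figsize=(10, 6))
for i, (coef, intercept, color, target_name) in enumerate(
    zip(clf.coef_, clf.intercept_, colors, iris.target_names)
):
    # Boundary for class i: coef[0] * x + coef[1] * y + intercept = 0, solved for y
    yy = -(coef[0] * xx + intercept) / coef[1]
    plt.plot(xx, yy, color=color)
    plt.scatter(
        X_train[y_train == i, 0],
        X_train[y_train == i, 1],
        color=color, label=target_name, marker="x",
    )
plt.ylim(y_min, y_max)  # keep steep boundary lines from stretching the axes
plt.title('Decision Boundaries and Data Points')
plt.xlabel('Sepal Length (Standardized)')
plt.ylabel('Sepal Width (Standardized)')
plt.legend(loc="upper right")
plt.show()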

2D Charts of Sepal Length (Standardized) and Sepal Width (Standardized) with Learned Decision Boundaries Overlaid

As expected, the classifier learned a decision boundary for setosa that perfectly separates it from the other two classes in the training data. The other two decision boundaries leave some points of the other class on the wrong side, which makes sense because versicolor and virginica are not linearly separable.
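The coefficient table can also be read directly as a scoring rule. The sketch below re-creates the same setup (so it runs on its own) and checks two things: each class's score is the dot product of the features with that class's row of `clf.coef_` plus its intercept, and `predict` picks the class with the highest one-vs-rest score.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Rebuild the same setup as above (sepal features, standardized)
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, 0:2], iris.target, test_size=0.2, random_state=42
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
clf = SGDClassifier(random_state=42).fit(X_train, y_train)

# One score per class: coef_ has shape (3, 2), intercept_ has shape (3,)
scores = X_test @ clf.coef_.T + clf.intercept_
print(np.allclose(scores, clf.decision_function(X_test)))  # True

# The predicted class is the one whose one-vs-rest score is largest
manual_pred = clf.classes_[np.argmax(scores, axis=1)]
print(np.array_equal(manual_pred, clf.predict(X_test)))  # True
```

A point that lands on the positive side of several boundaries (or of none) is still assigned a single class, because only the relative size of the three scores matters.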

Generate Test Predictions and Measure Performance

Now that we have a trained model, we can use it to generate predictions for our test set. We can then evaluate the performance of our model by comparing our predictions to the actual labels. One of the simplest performance metrics is accuracy, which measures the proportion of correctly classified instances:

from sklearn.metrics import accuracy_score

y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")

# Accuracy: 0.70
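Accuracy is just the fraction of predictions that match the true labels; with 30 test examples, 0.70 means 21 were classified correctly. This sketch computes it by hand on hypothetical labels and compares the result with `accuracy_score`.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical labels for illustration
y_true = np.array([0, 1, 2, 2, 1, 0, 2, 1, 0, 2])
y_pred = np.array([0, 2, 2, 2, 1, 0, 1, 2, 0, 2])

# Accuracy = (number of matching labels) / (total number of labels)
manual = np.mean(y_true == y_pred)
print(manual)                          # 0.7
print(accuracy_score(y_true, y_pred))  # 0.7
```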

For classification tasks, the confusion matrix is another valuable tool for understanding model performance. For a binary problem, it breaks the predictions down into the following four categories:

  1. True Positives (TP): The model correctly predicted the positive class.
  2. True Negatives (TN): The model correctly predicted the negative class.
  3. False Positives (FP): The model incorrectly predicted the positive class.
  4. False Negatives (FN): The model incorrectly predicted the negative class.
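The four categories above can be counted directly from the labels. This sketch uses hypothetical binary labels (1 = positive, 0 = negative) and checks the counts against scikit-learn's `confusion_matrix`, which for binary labels lays the matrix out as [[TN, FP], [FN, TP]].

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical binary labels: 1 = positive class, 0 = negative class
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 1])
y_pred = np.array([1, 0, 1, 0, 1, 0, 0, 1])

tp = np.sum((y_pred == 1) & (y_true == 1))  # predicted positive, actually positive
tn = np.sum((y_pred == 0) & (y_true == 0))  # predicted negative, actually negative
fp = np.sum((y_pred == 1) & (y_true == 0))  # predicted positive, actually negative
fn = np.sum((y_pred == 0) & (y_true == 1))  # predicted negative, actually positive
print(tp, tn, fp, fn)  # 3 3 1 1

# ravel() flattens [[TN, FP], [FN, TP]] row by row
tn_sk, fp_sk, fn_sk, tp_sk = confusion_matrix(y_true, y_pred).ravel()
print(tn_sk, fp_sk, fn_sk, tp_sk)  # 3 1 1 3
```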

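For multi-class problems like the Iris dataset, the matrix has one row per true class and one column per predicted class, so entry (i, j) counts the test examples of class i that were predicted as class j. The diagonal holds the correct predictions: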

from sklearn.metrics import confusion_matrix

cm = pd.DataFrame(
    confusion_matrix(y_test, y_pred),
    index=iris.target_names,  # rows: true class
    columns=[f"Predicted {name}" for name in iris.target_names],
)
print(cm)

#             Predicted setosa  Predicted versicolor  Predicted virginica
# setosa                    10                     0                    0
# versicolor                 0                     0                    9
# virginica                  0                     0                   11
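Reading the multi-class matrix one class at a time recovers the four binary categories. This sketch hard-codes the matrix printed above and computes one-vs-rest TP, FP, FN, and TN for each class; it makes the weak spot explicit: all 9 versicolor test examples were predicted as virginica.

```python
import numpy as np

# The confusion matrix printed above (rows = true class, columns = predicted class)
cm = np.array([
    [10, 0, 0],   # setosa
    [0, 0, 9],    # versicolor
    [0, 0, 11],   # virginica
])
names = ["setosa", "versicolor", "virginica"]
total = cm.sum()

counts = {}
for k, name in enumerate(names):
    tp = int(cm[k, k])                # class k predicted as class k
    fn = int(cm[k, :].sum()) - tp     # class k predicted as something else
    fp = int(cm[:, k].sum()) - tp     # other classes predicted as class k
    tn = int(total) - tp - fn - fp    # everything else
    counts[name] = (tp, fp, fn, tn)
    print(f"{name}: TP={tp} FP={fp} FN={fn} TN={tn}")
```

Here versicolor has no true positives at all, and the same 9 examples show up as false positives for virginica.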

K-Fold Cross-Validation

When we first split our dataset, we randomly sectioned off a portion of the data to use as test data; however, this means that our measured performance is dependent on this random split and not particularly stable. Cross-validation is a technique to more robustly assess the performance of the model by dividing the dataset into K different subsets (or "folds"). The model is trained on all but one of the K folds and tested on the remaining fold. This process is then repeated K times, each time with a different fold as the test set.

Here are two reasons why cross-validation is useful:

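  1. Mitigating Overfitting to One Split: By training and testing on several different splits, the performance estimate no longer hinges on one lucky or unlucky train-test split, and it becomes harder to unknowingly tune the model to one particular test set.
  2. Better Utilization of Data: Since each data point appears in a test fold exactly once (and in the training folds the other K-1 times), we make the most of the available data.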
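The "exactly once" property can be checked directly. This sketch uses the same `KFold(5, shuffle=True, random_state=42)` configuration as the code that follows, applied to 150 placeholder rows (the size of the Iris dataset), and verifies that the five test folds cover every index exactly once.

```python
import numpy as np
from sklearn.model_selection import KFold

n_samples = 150  # size of the Iris dataset
X_placeholder = np.zeros((n_samples, 1))  # only the row count matters for splitting
cv = KFold(5, shuffle=True, random_state=42)

test_indices = []
for fold, (train_idx, test_idx) in enumerate(cv.split(X_placeholder)):
    # Train and test indices within a fold never overlap
    assert len(np.intersect1d(train_idx, test_idx)) == 0
    print(f"Fold {fold}: {len(train_idx)} train, {len(test_idx)} test")
    test_indices.append(test_idx)

# Across all folds, every sample index lands in a test fold exactly once
all_test = np.sort(np.concatenate(test_indices))
print(np.array_equal(all_test, np.arange(n_samples)))  # True
```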

We will use 5-fold cross-validation:

from scipy.stats import sem  # Standard Error of the Mean
from sklearn.model_selection import cross_val_score, KFold
from sklearn.pipeline import Pipeline

# Create a pipeline to properly fit and transform data for each fold.
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('model', SGDClassifier()),
])
# Applying 5-fold cross-validation
cv = KFold(5, shuffle=True, random_state=42)
# Use the raw (unscaled) sepal features; the pipeline standardizes within each fold
scores = cross_val_score(pipeline, iris.data[:, :2], iris.target, cv=cv)
print(f"Scores: {scores}")
print(f"Mean CV Accuracy: {scores.mean():.2f} (+/- {sem(scores):.2f})")

# Scores: [0.8        0.73333333 0.73333333 0.76666667 0.53333333]
# Mean CV Accuracy: 0.71 (+/- 0.05)
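The "+/-" figure is the standard error of the mean: the sample standard deviation of the fold scores divided by the square root of the number of folds. This sketch recomputes it from the five fold scores printed above and compares it with `scipy.stats.sem`, which uses ddof=1 by default.

```python
import numpy as np
from scipy.stats import sem

# Fold accuracies printed above
scores = np.array([0.8, 0.73333333, 0.73333333, 0.76666667, 0.53333333])

mean = scores.mean()
# Standard error of the mean: sample std (ddof=1) / sqrt(number of folds)
manual_sem = scores.std(ddof=1) / np.sqrt(len(scores))

print(f"{mean:.2f} +/- {manual_sem:.2f}")  # 0.71 +/- 0.05
print(np.isclose(manual_sem, sem(scores)))  # True
```

The single low fold (0.53) accounts for most of the spread, which is exactly the kind of split-to-split variation a single train-test split would hide.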


And that's it! We've trained a linear classification model using scikit-learn on the Iris plants dataset. We've also covered some essential methods for evaluating and visualizing your results. Knowing how to train linear models and analyze them is a foundational skill for any data scientist or machine learning engineer.

But don't stop here! Take this newfound knowledge and apply it to other datasets. Challenge yourself by tweaking parameters, exploring other algorithms, or delving into more advanced topics. And if you found this guide helpful, consider sharing it or leaving a comment. We love hearing from readers and are eager to know how you're putting these skills to use in your projects!
