使用Python的欠采样技术

使用Python实现欠采样技术

随着数字化领域的发展,大量数据正在从各种不同的源头生成和捕获。虽然这些宝贵的信息非常有价值,但往往反映出现实世界现象的不平衡分布。不平衡数据的问题不仅仅是一个统计挑战,它对数据驱动模型的准确性和可靠性有着深远的影响。

以金融行业中日益增长和普遍存在的欺诈检测为例。尽管我们想尽量避免欺诈行为因其高度破坏性而带来的损害,但机器(甚至人类)不可避免地需要从欺诈交易的例子中学习(尽管罕见),以将其与每天的合法交易数目区分开来。

欺诈和非欺诈交易之间的数据分布不平衡,给旨在检测此类异常活动的机器学习模型带来了重大挑战。如果没有适当处理数据不平衡,这些模型可能会偏向于将交易预测为合法,从而可能忽视欺诈的罕见情况。

医疗保健是另一个利用机器学习模型来预测不平衡结果的领域,例如癌症或罕见的遗传性疾病。这些结果的发生频率远低于它们的良性对应物。因此,基于这种不平衡数据训练的模型更容易出现错误的预测和诊断。这样的健康警报漏掉了模型的最初目的,即早期疾病检测。

这些只是几个例子,突出了数据不平衡的深远影响,即一类显著超过另一类。过采样和欠采样是两种常见的数据预处理技术,以平衡数据集,本文将重点介绍欠采样。

让我们讨论一些常用的欠采样给定分布的方法。

 

深入理解不平衡的缺点

 

我们先通过一个例子来更好地理解欠采样技术的重要性。下面的可视化展示了每个类别的相对数量对线性核的支持向量机的影响。以下代码和绘图引用自Kaggle notebook

import matplotlib.pyplot as pltfrom sklearn.svm import LinearSVCimport numpy as npfrom collections import Counterfrom sklearn.datasets import make_classificationdef create_dataset(    n_samples=1000, weights=(0.01, 0.01, 0.98), n_classes=3, class_sep=0.8, n_clusters=1):    return make_classification(        n_samples=n_samples,        n_features=2,        n_informative=2,        n_redundant=0,        n_repeated=0,        n_classes=n_classes,        n_clusters_per_class=n_clusters,        weights=list(weights),        class_sep=class_sep,        random_state=0,    )def plot_decision_function(X, y, clf, ax):    plot_step = 0.02    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1    xx, yy = np.meshgrid(        np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step)    )    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])    Z = Z.reshape(xx.shape)    ax.contourf(xx, yy, Z, alpha=0.4)    ax.scatter(X[:, 0], X[:, 1], alpha=0.8, c=y, edgecolor="k")fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))ax_arr = (ax1, ax2, ax3, ax4)weights_arr = (    (0.01, 0.01, 0.98),    (0.01, 0.05, 0.94),    (0.2, 0.1, 0.7),    (0.33, 0.33, 0.33),)for ax, weights in zip(ax_arr, weights_arr):    X, y = create_dataset(n_samples=1000, weights=weights)    clf = LinearSVC().fit(X, y)    plot_decision_function(X, y, clf, ax)    ax.set_title("Linear SVC with y={}".format(Counter(y)))

上面的代码从一个高度不平衡的数据集中生成四个不同分布的图。第二个和第三个图中包含了93%和69%的来自一类的实例,而最后一个图中的分布是完全平衡的,即每个类别贡献三分之一的实例。从最不平衡到最平衡的数据集的图像如下所示。在对这些数据进行SVM拟合时,第一个图中的超平面(高度不平衡)被推向图表的一侧,主要是因为算法平等地对待每个实例,无论其类别如何,并尝试以最大间隔来分离类别。因此,中心附近的大多数黄色实例将超平面推向角落,使算法错误地将少数类别分类错误。

随着数据分布变得更加平衡,算法成功地对所有感兴趣的类别进行分类。

总而言之,当一个数据集被一个或几个类别主导时,结果往往会导致具有更高错误分类的模型。然而,随着每个类别的观测分布趋向于均匀分割,分类器展现出减小的偏差。

在这种情况下,对黄色点进行欠采样是解决罕见类别问题导致模型错误的最简单解决方案。值得注意的是,并非所有数据集都会遇到这个问题,但对于那些遇到这个问题的数据集来说,纠正这种不平衡形成建模数据的关键初步步骤。

Imbalanced-Learn库

我们将使用Imbalanced-Learn Python库 (imbalanced-learn或imblearn)。我们可以使用pip来安装它:

pip install -U imbalanced-learn

实操!

让我们讨论并实验一些最流行的欠采样技术。假设您拥有一个二元分类的数据集,其中类别’0’明显超过类别’1’。

NearMiss欠采样

NearMiss是一种减少多数类样本数量接近少数类的欠采样技术。这将使任何使用空间分离或在两个类之间分割维度空间的算法能够进行清晰的分类。NearMiss有三个版本:

NearMiss-1: 多数类样本与三个最近少数类样本的最小平均距离。

NearMiss-2: 多数类样本与三个最远少数类样本的最小平均距离。

NearMiss-3: 多数类样本与每个少数类样本的最小距离。

让我们通过代码示例演示NearMiss-1欠采样算法:

# 导入所需的库和模块
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import NearMiss

# 生成不同类权重的数据集
features, labels = make_classification(
    n_samples=1000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.95, 0.05],
    flip_y=0,
    random_state=0,
)

# 打印类的分布情况
dist_classes = Counter(labels)
print("Undersampling前:")
print(dist_classes)

# 生成实例的散点图,按类别进行标记
for class_label, _ in dist_classes.items():
    instances = np.where(labels == class_label)[0]
    plt.scatter(features[instances, 0], features[instances, 1], label=str(class_label))
plt.legend()
plt.show()

# 设置欠采样方法
undersampler = NearMiss(version=1, n_neighbors=3)

# 将欠采样应用于数据集
features, labels = undersampler.fit_resample(features, labels)

# 打印新的类分布情况
dist_classes = Counter(labels)
print("Undersampling后:")
print(dist_classes)

# 生成实例的散点图,按类别进行标记
for class_label, _ in dist_classes.items():
    instances = np.where(labels == class_label)[0]
    plt.scatter(features[instances, 0], features[instances, 1], label=str(class_label))
plt.legend()
plt.show()

在 NearMiss() 类中将 version=1 更改为 version=2 或 version=3,以使用 NearMiss-2 或 NearMiss-3 欠采样算法。

 

 

NearMiss-2 选择两个类之间重叠区域的核心实例。通过 NeverMiss-3 算法,我们观察到与多数类区域重叠的少数类中的每个实例具有最多三个来自多数类的邻居。上面的代码示例中的属性 n_neighbors 定义了这一点。

 

简化最近邻居(CNN)规则

 

该方法首先将多数类的子集视为噪声。然后,它使用 1-最近邻算法来分类实例。如果来自多数类的实例被错误分类,则将其包含在子集中。该过程持续进行直到没有更多的实例被包含在子集中。

from imblearn.under_sampling import CondensedNearestNeighbourcnn = CondensedNearestNeighbour(random_state=42)X_res, y_res = cnn.fit_resample(X, y)

 

Tomek 链下采样

 

Tomek 链是相互靠近的不同类实例对。删除每对的多数类实例增加了两个类之间的空间,有助于分类过程。

from imblearn.under_sampling import TomekLinkstl = TomekLinks()X_res, y_res = tl.fit_resample(X, y)print('Original dataset shape:', Counter(y))print('Resample dataset shape:', Counter(y_res))

 

通过这些内容,我们深入了解了 Python 中的欠采样技术的基本方面,涵盖了三种显著的方法:Near Miss 欠采样、简化最近邻居和 Tomek 链下采样。

欠采样是解决机器学习中类别不平衡问题的重要数据处理步骤,还有助于改善模型性能和公平性。每种技术都具有独特的优势,并可以根据特定数据集和机器学习项目的目标进行定制。

本文全面介绍了欠采样方法及其在 Python 中的应用。希望它能够帮助您在处理机器学习项目中的类别不平衡挑战时做出明智的决策。

[Vidhi Chugh](https://vidhi-chugh.medium.com/)是一位人工智能战略家和数字转型领导者,致力于构建可扩展的机器学习系统。她是一位屡获殊荣的创新领导者、作者和国际演讲者。她的使命是使机器学习民主化,并打破术语壁垒,使每个人都能参与到这一变革中来。