跳到主要内容

分类算法

问题

数据分析中常用的分类算法有哪些?

答案

常用分类算法对比

算法可解释性性能适用场景
逻辑回归⭐⭐⭐⭐⭐⭐⭐⭐线性可分、需要解释
决策树⭐⭐⭐⭐⭐⭐⭐规则提取、小数据
随机森林⭐⭐⭐⭐⭐⭐⭐通用分类、特征重要性
XGBoost⭐⭐⭐⭐⭐⭐⭐竞赛、高精度需求

决策树

决策树通过一系列 if-else 规则进行分类,非常直观:

from sklearn.tree import DecisionTreeClassifier, export_text

# 预测用户是否会购买
X = df[['age', 'income', 'browse_count', 'cart_count']]
y = df['will_purchase']

model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X, y)

# 输出决策规则(可直接给业务方看)
print(export_text(model, feature_names=list(X.columns)))

输出示例:

|--- cart_count <= 2.50
| |--- browse_count <= 5.50
| | |--- class: 0 (不购买)
| |--- browse_count > 5.50
| | |--- class: 1 (购买)
|--- cart_count > 2.50
| |--- class: 1 (购买)

随机森林

随机森林 = 多棵决策树投票,减少过拟合:

from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt

model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
model.fit(X_train, y_train)

# 特征重要性(数据分析师最关注的输出)
importance = pd.Series(model.feature_importances_, index=X.columns)
importance.sort_values(ascending=True).plot(kind='barh', title='特征重要性')
plt.tight_layout()
plt.show()
特征重要性的价值

特征重要性告诉我们"哪些因素对结果影响最大",这是数据分析师向业务方汇报的核心输出。

XGBoost

from xgboost import XGBClassifier

model = XGBClassifier(
n_estimators=200,
max_depth=4,
learning_rate=0.1,
random_state=42
)
model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)

print(f'AUC = {roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]):.4f}')

常见面试问题

Q1: 决策树和随机森林的区别?

答案

  • 决策树:单棵树,容易过拟合,但可解释性强
  • 随机森林:多棵树投票(Bagging),泛化能力更强
  • 随机森林通过随机选择特征子集 + 随机采样降低方差

Q2: 什么时候用逻辑回归,什么时候用 XGBoost?

答案

场景推荐
需要向业务方解释原因逻辑回归 / 决策树
追求预测精度XGBoost
数据量小逻辑回归
特征多且非线性XGBoost / 随机森林

相关链接