分类算法
问题
数据分析中常用的分类算法有哪些?
答案
常用分类算法对比
| 算法 | 可解释性 | 性能 | 适用场景 |
|---|---|---|---|
| 逻辑回归 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | 线性可分、需要解释 |
| 决策树 | ⭐⭐⭐⭐ | ⭐⭐⭐ | 规则提取、小数据 |
| 随机森林 | ⭐⭐⭐ | ⭐⭐⭐⭐ | 通用分类、特征重要性 |
| XGBoost | ⭐⭐ | ⭐⭐⭐⭐⭐ | 竞赛、高精度需求 |
决策树
决策树通过一系列 if-else 规则进行分类,非常直观:
from sklearn.tree import DecisionTreeClassifier, export_text
# 预测用户是否会购买
X = df[['age', 'income', 'browse_count', 'cart_count']]
y = df['will_purchase']
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X, y)
# 输出决策规则(可直接给业务方看)
print(export_text(model, feature_names=list(X.columns)))
输出示例:
|--- cart_count <= 2.50
| |--- browse_count <= 5.50
| | |--- class: 0 (不购买)
| |--- browse_count > 5.50
| | |--- class: 1 (购买)
|--- cart_count > 2.50
| |--- class: 1 (购买)
随机森林
随机森林 = 多棵决策树投票,减少过拟合:
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
model.fit(X_train, y_train)
# 特征重要性(数据分析师最关注的输出)
importance = pd.Series(model.feature_importances_, index=X.columns)
importance.sort_values(ascending=True).plot(kind='barh', title='特征重要性')
plt.tight_layout()
plt.show()
特征重要性的价值
特征重要性告诉我们"哪些因素对结果影响最大",这是数据分析师向业务方汇报的核心输出。
XGBoost
from xgboost import XGBClassifier
model = XGBClassifier(
n_estimators=200,
max_depth=4,
learning_rate=0.1,
random_state=42
)
model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
print(f'AUC = {roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]):.4f}')
常见面试问题
Q1: 决策树和随机森林的区别?
答案:
- 决策树:单棵树,容易过拟合,但可解释性强
- 随机森林:多棵树投票(Bagging),泛化能力更强
- 随机森林通过随机选择特征子集 + 随机采样降低方差
Q2: 什么时候用逻辑回归,什么时候用 XGBoost?
答案:
| 场景 | 推荐 |
|---|---|
| 需要向业务方解释原因 | 逻辑回归 / 决策树 |
| 追求预测精度 | XGBoost |
| 数据量小 | 逻辑回归 |
| 特征多且非线性 | XGBoost / 随机森林 |