02-机器学习基础: 监督学习——集成学习

张开发
2026/4/18 23:37:20 · 15 分钟阅读

分享文章

02-机器学习基础: 监督学习——集成学习
# 集成学习:随机森林
# 一、为什么需要集成学习?
# 1.1 弱分类器的困境
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import warnings

# NOTE: this silences *every* warning, including matplotlib's
# "Glyph ... missing from current font" warnings, so a missing CJK font
# would otherwise fail silently (see the font setup below).
warnings.filterwarnings('ignore')

# The titles and axis labels below are Chinese. Matplotlib's default
# sans-serif font (DejaVu Sans) has no CJK glyphs, so list common CJK fonts
# first; matplotlib uses an installed one and falls back to DejaVu Sans.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC',
    'Noto Sans CJK SC', 'WenQuanYi Micro Hei', 'DejaVu Sans',
]
# Many CJK fonts lack the Unicode minus sign; use ASCII '-' on tick labels.
plt.rcParams['axes.unicode_minus'] = False

print("=" * 60)
print("集成学习:三个臭皮匠,顶个诸葛亮")
print("=" * 60)

# Noisy two-moons data; hold out 30% of the samples for evaluation.
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42)

# One decision tree vs. a 50-tree random forest with the same depth cap.
single_tree = DecisionTreeClassifier(max_depth=5, random_state=42)
random_forest = RandomForestClassifier(
    n_estimators=50, max_depth=5, random_state=42)
single_tree.fit(X_train, y_train)
random_forest.fit(X_train, y_train)

single_acc = single_tree.score(X_test, y_test)
rf_acc = random_forest.score(X_test, y_test)
print(f"\n单棵决策树准确率:{single_acc:.4f}")
print(f"随机森林准确率:{rf_acc:.4f}")
print(f"提升:{rf_acc-single_acc:.4f}")


def plot_decision_boundary(model, X, y, ax, title, padding=0.5, resolution=200):
    """Draw a fitted binary classifier's decision regions with the points on top.

    The model is evaluated on a regular grid that spans the data range plus
    ``padding`` on every side; the predicted labels are drawn as filled
    contours, then the points of each class are scattered over them.

    Args:
        model: fitted estimator with a ``predict`` method accepting an
            (n, 2) array of points.
        X: array of points; only the first two columns are used.
        y: labels for ``X``; only classes 0 and 1 are scattered.
        ax: matplotlib Axes to draw on.
        title: title for ``ax``.
        padding: margin added around the data range on each axis.
        resolution: number of grid samples per axis.
    """
    x_min, x_max = X[:, 0].min() - padding, X[:, 0].max() + padding
    y_min, y_max = X[:, 1].min() - padding, X[:, 1].max() + padding
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    # 'RdBu_r' maps low values (class 0) to blue and high values (class 1)
    # to red, matching the scatter colors below. Plain 'RdBu' is reversed
    # and painted each region in the *other* class's color.
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='RdBu_r')
    ax.scatter(X[y == 0, 0], X[y == 0, 1], c='blue', alpha=0.5, s=15)
    ax.scatter(X[y == 1, 0], X[y == 1, 1], c='red', alpha=0.5, s=15)
    ax.set_title(title)
    ax.set_xlabel('特征1')
    ax.set_ylabel('特征2')
    ax.grid(True, alpha=0.3)


# Side-by-side comparison on the held-out test points.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
plot_decision_boundary(single_tree, X_test, y_test, axes[0],
                       f'单棵决策树\n准确率={single_acc:.3f}')
plot_decision_boundary(random_forest, X_test, y_test, axes[1],
                       f'随机森林\n准确率={rf_acc:.3f}')
plt.suptitle('单棵决策树 vs 随机森林', fontsize=14)
plt.tight_layout()
plt.show()

print("\n💡 集成学习的核心思想:")
print(" 1. 多个弱分类器组合成强分类器")
print(" 2. 不同分类器在不同区域有优势")
print(" 3. 投票决定最终结果,减少方差")

# 二、Bagging与随机森林原理
# 2.1 Bootstrap采样
def visualize_bootstrap():
    """可视化Bootstrap采样"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    # 1.
Bootstrap采样过程ax1=axes[0]ax1.axis('off')ax1.set_title('Bootstrap采样(有放回抽样)',fontsize=12)# 原始数据original_data=['A','B','C','D','E','F','G','H']n_samples=len(original_data)# 绘制原始数据fori,iteminenumerate(original_data):circle=plt.Circle((0.1+i*0.08,0.7),0.04,color='lightblue',ec='black')ax1.add_patch(circle)ax1.text(0.1+i*0.08,0.7,item,ha='center',va='center',fontsize=8)ax1.text(0.5,0.85,'原始数据集 (8个样本)',ha='center',fontsize=10)# Bootstrap采样结果np.random.seed(42)bootstrap_samples=[]foriinrange(3):sample=np.random.choice(original_data,n_samples,replace=True)bootstrap_samples.append(sample)y_pos=0.5fori,sampleinenumerate(bootstrap_samples):ax1.text(0.05,y_pos,f'采样{i+1}:',fontsize=9,fontweight='bold')forj,iteminenumerate(sample):ax1.text(0.15+j*0.06,y_pos,item,fontsize=8,bbox=dict(boxstyle='circle',facecolor='lightgreen',alpha=0.7))y_pos-=0.12ax1.text(0.5,0.15,'特点: 部分样本重复,部分样本未出现',ha='center',fontsize=9,bbox=dict(boxstyle='round',facecolor='lightyellow'))# 2. Bagging架构ax2=axes[1

更多文章