本文通过数据科学和AI的方法,分析挖掘人力资源流失问题,构建基于机器学习的解决方案,并通过对AI模型的反向解释,深入理解导致人员流失的主要因素。 ?? 作者:韩信子@ShowMeAI
?? 数据分析实战系列:https://www.showmeai.tech/tutorials/40
?? 机器学习实战系列:https://www.showmeai.tech/tutorials/41
?? 本文地址:https://www.showmeai.tech/article-detail/308
?? 声明:版权所有,转载请联系平台与作者并注明出处
?? 收藏ShowMeAI查看更多精彩内容
人力资源是组织的一个部门,负责处理员工的招聘、培训、管理和福利。一个组织每年都会雇佣几名员工,并投入大量时间、金钱和资源来提高员工的绩效和效率。每家公司都希望能够吸引和留住优秀的员工,失去一名员工并再次雇佣一名新员工的成本是非常高的,HR部门需要知道雇用和留住重要和优秀员工的核心因素是什么,那那么可以做得更好。
在本项目中,ShowMeAI 带大家通过数据科学和AI的方法,分析挖掘人力资源流失问题,并基于机器学习构建解决问题的方法,并且,我们通过对AI模型的反向解释,可以深入理解导致人员流失的主要因素,HR部门也可以根据分析做出正确的决定。
本篇涉及到的数据集大家可以通过 ShowMeAI 的百度网盘地址获取。
?? 实战数据集下载(百度网盘):公众号『ShowMeAI研究中心』回复『实战』,或者点击 这里 获取本文 [17]人力资源流失场景机器学习建模与调优 『HR-Employee-Attrition 数据集』
? ShowMeAI官方GitHub:https://github.com/ShowMeAI-Hub
和 ShowMeAI 之前介绍过的所有AI项目一样,我们需要先对场景数据做一个深度理解,这就是我们提到的EDA(Exploratory Data Analysis,探索性数据分析)过程。
EDA部分涉及的工具库,大家可以参考ShowMeAI制作的工具库速查表和教程进行学习和快速使用。
??数据科学工具库速查表 | Pandas 速查表
??数据科学工具库速查表 | Seaborn 速查表
??图解数据分析:从入门到精通系列教程
我们本次使用到的数据集字段基本说明如下:
| 列名 | 含义 |
|---|---|
| Age | 年龄 |
| Attrition | 离职 |
| BusinessTravel | 出差:0-不出差、1-偶尔出差、2-经常出差 |
| Department | 部门:1-人力资源、2-科研、3-销售 |
| DistanceFromHome | 离家距离 |
| Education | 教育程度:1-大学一下、2-大学、3-学士、4-硕士、5-博士 |
| EducationField | 教育领域 |
| EnvironmentSatisfaction | 环境满意度 |
| Gender | 性别:1-Mae男、0- Female女 |
| Joblnvolvement | 工作投入 |
| JobLevel | 职位等级 |
| JobRole | 工作岗位 |
| JobSatisfaction | 工作满意度 |
| Maritalstatus | 婚姻状况:0- Divorced离婚、1- Single未婚、2-已婚 |
| Monthlylncome | 月收入 |
| NumCompaniesWorked | 服务过几家公司 |
| OverTime | 加班 |
| RelationshipSatisfaction | 关系满意度 |
| StockOptionLevel | 股权等级 |
| TotalworkingYears | 总工作年限 |
| TrainingTimesLastYear | 上一年培训次数 |
| WorkLifeBalance | 工作生活平衡 |
| YearsAtCompany | 工作时长 |
| YearsInCurrentRole | 当前岗位在职时长 |
| YearsSinceLastPromotion | 上次升职时间 |
| YearsWithCurrManager | 和现任经理时长 |
下面我们先导入所需工具库、加载数据集并查看数据基本信息:
import pandas as pdimport numpy as npimport matplotlib as mplimport matplotlib.pyplot as pltimport seaborn as snssns.set_style("darkgrid")import warningswarnings.filterwarnings("ignore")pd.set_option('display.max_columns',100)print("import complete")# 读取数据data = pd.read_csv("HR-Employee-Attrition.csv") data.head()查看前 5 条数据记录后,我们了解了一些基本信息:
① 数据包含『数值型』和『类别型』两种类型的特征。
② 有不少离散的数值特征。
接下来我们借助工具库进一步探索数据。
# 字段、类型、缺失情况data.info()我们使用命令 data.info``() 来获取数据的信息,包括总行数(样本数)和总列数(字段数)、变量的数据类型、数据集中非缺失的数量以及内存使用情况。
从数据集的信息可以看出,一共有 35 个特征,Attrition 是目标字段,26 个变量是整数类型变量,9 个是对象类型变量。
我们先来做一下缺失值检测与处理,缺失值的存在可能会降低模型效果,也可能导致模型出现偏差。
# 查看缺失值情况data.isnull().sum()从结果可以看出,数据集中没有缺失值。
因为目标特征“Attrition”是一个类别型变量,为了分析方便以及能够顺利建模,我们对它进行类别编码(映射为整数值)。
#since Attrition is a categotical in nature so will be mapping it with integrs variables for further analysisdata.Attrition = data.Attrition.map({"Yes":1,"No":0})接下来,我们借助于pandas的describe函数检查数值特征的统计摘要:
#checking statistical summarydata.describe().T注意这里的“.T”是获取数据帧的转置,以便更好地分析。
从统计摘要中,我们得到数据的统计信息,包括数据的中心趋势——平均值、中位数、众数和散布标准差和百分位数,最小值和最大值等。
我们进一步对数值型变量进行分析
# 选出数值型特征numerical_feat = data.select_dtypes(include=['int64','float64'])numerical_featprint(numerical_feat.columns)print("No. of numerical variables :",len(numerical_feat.columns))print("Number of unique values \n",numerical_feat.nunique())我们有以下观察结论:
① 共有27个数值型特征变量
② 月收入、日费率、员工人数、月费率等为连续数值
③ 其余变量为离散数值(即有固定量的候选取值)
我们借助于 seaborn 工具包中的分布图方法 sns.distplot() 来绘制数值分布图
# 数据分析&分布绘制plt.figure(figsize=(25,30))plot = 1for var in numerical_feat: plt.subplot(9,3,plot) sns.distplot(data[var],color='skyblue') plot+=1plt.show()通过以上分析,我们获得以下一些基本观察和结论:
接下来我们对目标变量做点分析:
# 目标变量分布sns.countplot('Attrition',data=data)plt.title("Distribution of Target Variable")plt.show()print(data.Attrition.value_counts())我们可以看到数据集中存在类别不平衡问题(流失的用户占比少)。类别不均衡情况下,我们要选择更有效的评估指标(如auc可能比accuracy更有效),同时在建模过程中做一些优化处理。
我们分别对各个字段和目标字段进行联合关联分析。
# Age 与 attritionage=pd.crosstab(data.Age,data.Attrition)age.div(age.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(14,7),cmap='spring')plt.title("Age vs Attrition",fontsize=20)plt.show()# Distance from home 与 attritiondist=pd.crosstab(data.DistanceFromHome,data.Attrition)dist.div(dist.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7))plt.title("Distance From Home vs Attrition",fontsize=20)plt.show()# Education 与 Attritionedu=pd.crosstab(data.Education,data.Attrition)edu.div(edu.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7))plt.title("Education vs Attrition",fontsize=20)plt.show()# Environment Satisfaction 与 Attritionesat=pd.crosstab(data.EnvironmentSatisfaction,data.Attrition)esat.div(esat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='BrBG')plt.title("Environment Satisfaction vs Attrition",fontsize=20)plt.show()# Job Involvement 与 Attritionjob_inv=pd.crosstab(data.JobInvolvement,data.Attrition)job_inv.div(job_inv.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Spectral')plt.title("Job Involvement vs Attrition",fontsize=20)plt.show()# Job Level 与 Attritionjob_lvl=pd.crosstab(data.JobLevel,data.Attrition)job_lvl.div(job_lvl.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='prism_r')plt.title("Job Level vs Attrition",fontsize=20)plt.show()# Job Satisfaction 与 Attritionjob_sat=pd.crosstab(data.JobSatisfaction,data.Attrition)job_sat.div(job_sat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='inferno')plt.title("Job Satisfaction vs Attrition",fontsize=20)plt.show()# Number of Companies Worked 与 Attritionnum_org=pd.crosstab(data.NumCompaniesWorked,data.Attrition)num_org.div(num_org.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='cividis_r')plt.title("Number of Companies Worked vs Attrition",fontsize=20)plt.show()# Percent Salary Hike 与 Attritionsal_hike_percent=pd.crosstab(data.PercentSalaryHike,data.Attrition)sal_hike_percent.div(sal_hike_percent.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='RdYlBu')plt.title("Percent Salary Hike vs Attrition",fontsize=20)plt.show()# Performance Rating 与 Attritionperformance=pd.crosstab(data.PerformanceRating,data.Attrition)performance.div(performance.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='viridis_r')plt.title("Performance Rating vs Attrition",fontsize=20)plt.show()# Relationship Satisfaction 与 Attritionrel_sat=pd.crosstab(data.RelationshipSatisfaction,data.Attrition)rel_sat.div(rel_sat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='brg_r')plt.title("Relationship Satisfaction vs Attrition",fontsize=20)plt.show()# Stock Option Level 与 Attritionstock_opt=pd.crosstab(data.StockOptionLevel,data.Attrition)stock_opt.div(stock_opt.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Accent')plt.title("Stock Option Level vs Attrition",fontsize=20)plt.show()# Training Times Last Year 与 Attritiontr_time=pd.crosstab(data.TrainingTimesLastYear,data.Attrition)tr_time.div(tr_time.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='coolwarm')plt.title("Training Times Last Year vs Attrition",fontsize=20)plt.show()# Work Life Balance 与 Attritionwork=pd.crosstab(data.WorkLifeBalance,data.Attrition)work.div(work.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='gnuplot')plt.title("Work Life Balance vs Attrition",fontsize=20)plt.show()# Years With Curr Manager 与 Attritioncurr_mang=pd.crosstab(data.YearsWithCurrManager,data.Attrition)curr_mang.div(curr_mang.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='OrRd_r')plt.title("Years With Curr Manager vs Attrition",fontsize=20)plt.show()# Years Since Last Promotion 与 Attritionprom=pd.crosstab(data.YearsSinceLastPromotion,data.Attrition)prom.div(prom.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='PiYG_r')plt.title("Years Since Last Promotion vs Attrition",fontsize=20)plt.show()# Years In Current Role 与 Attritionrole=pd.crosstab(data.YearsInCurrentRole,data.Attrition)role.div(role.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='terrain')plt.title("Years In Current Role vs Attrition",fontsize=20)plt.show()这些堆积条形图显示了员工流失情况与各个字段取值的关系,从上图我们可以得出以下基本结论:
现在我们对类别型特征进行分析,在这里我们使用饼图和堆积条形图来分析它们的分布以及它们和目标变量的相关性。
# 分析Buisness Travel colors=['red','green','blue']size = data.BusinessTravel.value_counts().valuesexplode_list=[0,0.05,0.1]plt.figure(figsize=(15,10))plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.15)plt.title("Business Travel",fontsize=15)plt.legend(labels=['Travel_Rarely','Travel_Frequently','Non-Travel'],loc='upper left') plt.show().valuesexplode_list=[0,0.05,0.06]plt.figure(figsize=(15,10))plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1)plt.title("Department",fontsize=15)plt.legend(labels=['Sales','Research & Development','Human Resources'],loc='upper left') plt.show().valuesexplode_list=[0,0.05,0.05,0.08,0.08,0.1]plt.figure(figsize=(15,10))plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1)plt.title("Education Field",fontsize=15)plt.legend(labels=['Life Sciences','Other','Medical','Marketing','Technical Degree','Human Resources'],loc='upper left') plt.show().valuesexplode_list=[0,0.05,0.05,0.05,0.08,0.08,0.08,0.1,0.1]plt.figure(figsize=(15,10))plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1)plt.title("Job Role",fontsize=15)plt.legend(labels=['Sales Executive','Research Scientist','Laboratory Technician','Manufacturing Director','Healthcare Representative','Manager','Sales Representative','Research Director','Human Resources'],loc='upper left') plt.show())plt.title('Gender distribution',fontsize=15)sns.countplot('Gender',data=data,palette='magma')trav.div(trav.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1')plt.title("Business Travel vs Attrition",fontsize=20)plt.show()# Department 与 Attritiondept = pd.crosstab(data.Department,data.Attrition)dept.div(dept.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1')plt.title("Department vs Attrition",fontsize=20)plt.show()# Education Field 与 Attritionedu_f = pd.crosstab(data.EducationField,data.Attrition)edu_f.div(edu_f.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1')plt.title("Education Field vs Attrition",fontsize=20)plt.show()# Job Role 与 Attritionjobrole = pd.crosstab(data.JobRole,data.Attrition)jobrole.div(jobrole.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1')plt.title("Job Role vs Attrition",fontsize=20)plt.show()# Marital Status 与 Attritionmary = pd.crosstab(data.MaritalStatus,data.Attrition)mary.div(mary.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1')plt.title("Marital Status vs Attrition",fontsize=20)plt.show()# gender 与 Attritionplt.figure(figsize=(10,9))plt.title('Gender distribution',fontsize=15)sns.countplot('Gender',data=data,palette='magma'))sns.heatmap(data.corr(method='spearman'),annot=True,cmap='Accent')plt.title('Correlation of features',fontsize=20)plt.show()# 相关度排序plt.figure(figsize=(15,9))correlation = data . corr(method='spearman')correlation.Attrition.sort_values(ascending=False).drop('Attrition').plot.bar(color='r')plt.title('Correlation of independent features with target feature',fontsize=20)plt.show()下面我们检测一下数据集中的异常值,在这里,我们使用箱线图来可视化分布并检测异常值。
# 绘制箱线图plot=1plt.figure(figsize=(15,30))for i in numerical_feat.columns: plt.subplot(9,3,plot) sns.boxplot(data[i],color='navy') plt.xlabel(i) plot+=1plt.show()箱线图显示数据集中有不少异常值,不过这里的异常值主要是因为离散变量(可能是取值较少的候选),我们将保留它们(不然会损失掉这些样本信息),不过我们注意到月收入的异常值比较奇怪,这可能是由于数据收集错误造成的,可以清洗一下。
关于机器学习特征工程,大家可以参考 ShowMeAI 整理的特征工程最全解读教程。
??机器学习实战 | 机器学习特征工程最全解读
下面我们来完成特征工程的部分,从原始数据中抽取强表征的信息,以便模型能更直接高效地挖掘和建模。
我们在EDA过程中发现 MonthlyIncome、JobLevel 和 YearsAtCompany 以及 YearsInCurrentRole 高度相关,可能会带来多重共线性问题,我们会做一些筛选,同时我们会删除一些与 EmployeeCount、StandardHours 等变量不相关的特征,并剔除一些对预测不重要的特征。
dataset = data.copy()# 删除与目标相关性低的Employee count 和 standard hours特征dataset.drop(['EmployeeCount','StandardHours'],inplace=True,axis=1)dataset.head(2)下面我们对类别型特征进行编码,包括数字映射与独热向量编码。
# 按照出差的频度进行编码dataset.BusinessTravel = dataset.BusinessTravel.replace({ 'Non-Travel':0,'Travel_Rarely':1,'Travel_Frequently':2 })# 性别与overtime编码dataset.Gender = dataset.Gender.replace({'Male':1,'Female':0})dataset.OverTime = dataset.OverTime.replace({'Yes':1,'No':0})# 独热向量编码 new_df = pd.get_dummies(data=dataset,columns=['Department','EducationField','JobRole', 'MaritalStatus'])new_df处理与转换后的数据如下所示:
在前面的数据探索分析过程中,我们发现目标变量是类别不平衡的,因此可能会导致模型偏向多数类而带来偏差。我们在这里会应用过采样技术 SMOTE(合成少数类别的样本补充)来处理数据集中的类别不平衡问题。
我们把数据先切分为特征和标签,处理之前的标签类别比例如下:
# 切分特征和标签X = new_df.drop(['Attrition'],axis=1)Y = new_df.Attrition# 标签01取值比例sns.countplot(data=new_df,x=Y,palette='Set1')plt.show()print(Y.value_counts())x,y = sm.fit_resample(X,Y)print(x.shape," \t ",y.shape)# (2466, 45) (2466,)过采样后
# 过采样之后的比例sns.countplot(data=new_df,x=y,palette='Set1')plt.show()print(y.value_counts())x_scaled = scaler.fit_transform(x)x_scaled = pd.DataFrame(x_scaled, columns=x.columns)x_scaled处理后我们的数据集看起来像这样
所有取值都已调整到 0 -1 的幅度范围内。
通常在特征工程之后,我们会得到非常多的特征,太多特征会带来模型训练性能上的问题,不相关的差特征甚至会拉低模型的效果。
我们很多时候会进行特征重要度分析的工作,筛选和保留有效特征,而对其他特征进行剔除。我们先将数据集拆分为训练集和测试集,再基于互信息判定特征重要度。
## 训练集测试集切分from sklearn.model_selection import train_test_splitxtrain,xtest,ytrain,ytest = train_test_split(x_scaled,y,test_size=0.3,random_state=1)我们使用 sklearn.feature_selection 类中的mutual_info_classif 方法来获得特征重要度。Mutual _info_classif的工作原理是类似信息增益。
from sklearn.feature_selection import mutual_info_classifmutual_info = mutual_info_classif(xtrain,ytrain)mutual_info下面我们绘制一下特征重要性
mutual_info = pd.Series(mutual_info)mutual_info.index = xtrain.columnsmutual_info.sort_values(ascending=False)plt.title("Feature Importance",fontsize=20)mutual_info.sort_values().plot(kind='barh',figsize=(12,9),color='r')plt.show()当然,实际判定特征重要度的方式有很多种,甚至结果也会有一些不同,我们只是基于这个步骤,进行一定的特征筛选,把最不相关的特征剔除。
关于建模与评估,大家可以参考 ShowMeAI 的机器学习系列教程与模型评估基础知识文章。
??图解机器学习算法:从入门到精通系列教程
??图解机器学习算法(1) | 机器学习基础知识
??图解机器学习算法(2) | 模型评估方法与准则
好,我们前序工作就算完毕啦!下面要开始构建模型了。在建模之前,有一件非常重要的事情,是我们需要选择合适的评估指标对模型进行评估,这能给我们指明模型优化的方向,我们在这里,针对分类问题,尽量覆盖地选择了下面这些评估指标
我们这里选用了8个模型构建baseline,并应用交叉验证以获得对模型无偏的评估结果。
# 导入工具库from sklearn.linear_model import LogisticRegressionfrom sklearn.tree import DecisionTreeClassifierfrom sklearn.svm import SVCfrom sklearn.neighbors import KNeighborsClassifierfrom sklearn.naive_bayes import BernoulliNBfrom sklearn.ensemble import RandomForestClassifierfrom sklearn.ensemble import AdaBoostClassifierfrom sklearn.ensemble import GradientBoostingClassifierfrom sklearn.model_selection import cross_val_score,cross_validatefrom sklearn.metrics import classification_report,confusion_matrix,accuracy_score,plot_roc_curve,roc_curve,auc,roc_auc_score,precision_score,r# 初始化baseline模型(使用默认参数)LR = LogisticRegression()KNN = KNeighborsClassifier()SVC = SVC()DTC = DecisionTreeClassifier()BNB = BernoulliNB()RTF = RandomForestClassifier()ADB = AdaBoostClassifier()GB = GradientBoostingClassifier()# 构建模型列表models = [("Logistic Regression ",LR), ("K Nearest Neighbor classifier ",KNN), ("Support Vector classifier ",SVC), ("Decision Tree classifier ",DTC), ("Random forest classifier ",RTF), ("AdaBoost classifier",ADB), ("Gradient Boosting classifier ",GB), ("Naive Bayes classifier",BNB)]接下来我们遍历这些模型进行训练和评估:
for name,model in models: model.fit(xtrain,ytrain) print(name," trained")# 遍历评估train_scores=[]test_scores=[]Model = []for name,model in models: print("******",name,"******") train_acc = accuracy_score(ytrain,model.predict(xtrain)) test_acc = accuracy_score(ytest,model.predict(xtest)) print('Train score : ',train_acc) print('Test score : ',test_acc) train_scores.append(train_acc) test_scores.append(test_acc) Model.append(name)# 不同的评估准则precision_ =[]recall_ = []f1score = []rocauc = []for name,model in models: print("******",name,"******") cm = confusion_matrix(ytest,model.predict(xtest)) print("\n",cm) fpr,tpr,thresholds=roc_curve(ytest,model.predict(xtest)) roc_auc= auc(fpr,tpr) print("\n","ROC_AUC_SCORE : ",roc_auc) rocauc.append(roc_auc) print(classification_report(ytest,model.predict(xtest))) precision = precision_score(ytest, model.predict(xtest)) print('Precision: ', precision) precision_.append(precision) recall = recall_score(ytest, model.predict(xtest)) print('Recall: ', recall) recall_.append(recall) f1 = f1_score(ytest, model.predict(xtest)) print('F1 score: ', f1) f1score.append(f1) plt.figure(figsize=(10,20)) plt.subplot(211) print(sns.heatmap(cm,annot=True,fmt='d',cmap='Accent')) plt.subplot(212) plt.plot([0,1],'k--') plt.plot(fpr,tpr) plt.xlabel('false positive rate') plt.ylabel('true positive rate') plt.show()我们把所有的评估结果汇总,得到一个模型结果对比表单
# 构建一个Dataframe存储所有模型的评估指标evaluate = pd.DataFrame({})evaluate['Model'] = Modelevaluate['Train score'] = train_scoresevaluate['Test score'] = test_scoresevaluate['Precision'] = precision_evaluate['Recall'] = recall_evaluate['F1 score'] = f1scoreevaluate['Roc-Auc score'] = rocaucevaluate我们从上述baseline模型的汇总评估结果里看到:
我们要看一下最终的交叉验证得分情况
# 查看交叉验证得分for name,model in models: print("******",name,"******") cv_= cross_val_score(model,x_scaled,y,cv=5).mean() print(cv_)从交叉验证结果上看,随机森林表现最优,我们把它选为最佳模型,并将进一步对它进行调优以获得更高的准确性。
关于建模与评估,大家可以参考ShowMeAI的相关文章。
??深度学习教程(7) | 网络优化:超参数调优、正则化、批归一化和程序框架
我们刚才建模过程,使用的都是模型的默认超参数,实际超参数的取值会影响模型的效果。我们有两种最常用的方法来进行超参数调优:
下面我们演示使用随机搜索调参优化。
from sklearn.model_selection import RandomizedSearchCVparams = {'n_estimators': [int(x) for x in np.linspace(start = 100, stop = 1200, num = 12)], 'criterion':['gini','entropy'], 'max_features': ['auto', 'sqrt'], 'max_depth': [int(x) for x in np.linspace(5, 30, num = 6)], 'min_samples_split': [2, 5, 10, 15, 100], 'min_samples_leaf': [1, 2, 5, 10] }random_search=RandomizedSearchCV(RTF,param_distributions=params,n_jobs=-1,cv=5,verbose=5)random_search.fit(xtrain,ytrain)拟合随机搜索后,我们取出最佳参数和最佳估计器。
random_search.best_params_random_search.best_estimator_我们对最佳估计器进行评估
# 最终模型final_mod = RandomForestClassifier(max_depth=10, max_features='sqrt', n_estimators=500)final_mod.fit(xtrain,ytrain)final_pred = final_mod.predict(xtest)print("Accuracy Score",accuracy_score(ytest,final_pred))cross_val = cross_val_score(final_mod,x_scaled,y,scoring='accuracy',cv=5).mean()print("Cross val score",cross_val)plot_roc_curve(final_mod,xtest,ytest)我们可以看到,超参数调优后:
最后我们对模型进行存储,以便后续使用或者部署上线。
import joblibjoblib.dump(final_mod,'hr_attrition.pkl')# ['hr_attrition.pkl']