# test_model_bias.py
import yaml
from pathlib import Path
import pandas as pd
import numpy as np
import paths as pt
from sklearn.model_selection import train_test_split
import xgboost as xgb
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from utility import metrics
from tools import data_loader, file_writer, file_reader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import recall_score, roc_auc_score
import matplotlib.pyplot as plt

def main():
    """Train an XGBoost fall classifier with 5-fold CV and audit gender bias.

    Loads settings from fall_emb.yaml and embeddings from fall_emb.csv,
    cross-validates the model, writes per-fold confusion-matrix CSVs
    (per gender group and overall) plus aggregated metric CSVs to the
    interim data dir, and saves gender-comparison plots to the reports dir.
    """
    # Load settings
    with open(Path.joinpath(pt.CONFIGS_DIR, "fall_emb.yaml"), 'r') as stream:
        settings = yaml.safe_load(stream)

    protected_col_name = "Gender_Male"
    y_col_name = "Fall"

    # Load the data
    file_name = "fall_emb.csv"
    dl = data_loader.AlarmDataLoader(file_name, settings).load_data()
    X, y = dl.get_data()
    # Gender is one-hot encoded; drop the complementary column so only
    # Gender_Male remains as the protected attribute.
    X = X.drop(['Gender_Female'], axis=1)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                        stratify=y, random_state=0)
    # Weight the positive (fall) class to compensate for class imbalance.
    neg, pos = np.bincount(y)
    scale_pos_weight = neg / pos

    params = {"n_estimators": 400,
              "objective": "binary:logistic",
              "scale_pos_weight": scale_pos_weight,
              "use_label_encoder": False,
              "learning_rate": 0.1,
              "eval_metric": "logloss",
              "random_state": 0}

    model = xgb.XGBClassifier(**params)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    df_test = pd.DataFrame([], columns=list(X.columns) + [y_col_name]
                           + ["output"] + ["output_prob"])

    i = 0
    # FIX: start from a float series — the original `0*y` kept the integer
    # label dtype while float probabilities are assigned into it below.
    y_valid_pred = 0.0 * y
    valid_acc, valid_pre, valid_recall, valid_roc_auc = list(), list(), list(), list()
    for train_index, valid_index in skf.split(X_train, y_train):

        X_train_split, X_valid_split = X_train.iloc[train_index, :], X_train.iloc[valid_index, :]
        y_train_split, y_valid_split = y_train.iloc[train_index], y_train.iloc[valid_index]
        optimize_rounds = True
        early_stopping_rounds = 50

        if optimize_rounds:
            # Early-stop on the held-out fold using the custom Gini metric.
            eval_set = [(X_valid_split, y_valid_split)]
            fit_model = model.fit(X_train_split, y_train_split,
                                  eval_set=eval_set,
                                  eval_metric=metrics.gini_xgb,
                                  early_stopping_rounds=early_stopping_rounds,
                                  verbose=False)
        else:
            fit_model = model.fit(X_train_split, y_train_split)

        pred = fit_model.predict_proba(X_valid_split)[:, 1]
        y_valid_pred.iloc[valid_index] = pred
        y_valid_scores = (y_valid_pred.iloc[valid_index] > 0.5)

        # Collect fold predictions next to the validation features.
        y_true_pd = y_valid_split.to_frame().reset_index(drop=True)
        y_valid_scores = y_valid_scores.apply(lambda x: 1 if x == True else 0).to_frame()
        y_pred_pd = y_valid_scores.reset_index(drop=True).rename(columns={y_col_name: "output"})
        y_pred_prob_pd = pd.DataFrame(pred, columns=["output_prob"])

        df_subset = pd.concat([X_valid_split.reset_index(drop=True), y_true_pd,
                               y_pred_pd, y_pred_prob_pd], axis=1)
        # FIX: DataFrame.append was removed in pandas 2.0 — use pd.concat.
        df_test = pd.concat([df_test, df_subset], ignore_index=True)

        # Confusion matrix per protected group for this fold.
        df_evaluate_proc = metrics.get_cm_by_protected_variable(df_subset, protected_col_name,
                                                                y_col_name, "output")
        file_writer.write_csv(df_evaluate_proc, pt.INTERIM_DATA_DIR, "model" + str(i)
                              + "_" + protected_col_name + ".csv")

        # Same confusion matrix over everyone via an "all" pseudo-group.
        df_evaluate_together = df_subset.copy()
        df_evaluate_together[protected_col_name] = "all"
        df_evaluate_all = metrics.get_cm_by_protected_variable(df_evaluate_together, protected_col_name,
                                                               y_col_name, "output")
        file_writer.write_csv(df_evaluate_all, pt.INTERIM_DATA_DIR, "model" + str(i)
                              + "_" + protected_col_name + "_all.csv")

        # Per-fold validation metrics (collected but not reported further).
        valid_acc.append(accuracy_score(y_valid_split, y_valid_scores))
        valid_pre.append(precision_score(y_valid_split, y_valid_scores))
        valid_recall.append(recall_score(y_valid_split, y_valid_scores))
        valid_roc_auc.append(roc_auc_score(y_valid_split, y_valid_pred.iloc[valid_index]))

        i = i + 1

    file_writer.write_csv(df_test, pt.INTERIM_DATA_DIR, "all_test_data.csv")

    # Evaluate on the held-out test split.
    # NOTE(review): `model` holds the fit from the LAST CV fold only — confirm
    # this is intended rather than a refit on the full training split.
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1]
    print(f"Accuracy: {np.around(accuracy_score(y_test, y_pred), decimals=3)}")
    print(f"Precision: {np.around(precision_score(y_test, y_pred), decimals=3)}")
    print(f"Recall: {np.around(recall_score(y_test, y_pred), decimals=3)}")
    print(f"ROC AUC: {np.around(roc_auc_score(y_test, y_proba), decimals=3)}\n")

    # Aggregate the "all" confusion data across folds.
    column_names = ["Group", "ML", "Measure", "Value"]
    measures = ['FPR', 'FNR', 'ACC', 'F1', 'FDR', 'LRminus', 'LRplus',
                'NPV', 'PPV', 'TNR', 'TPR', 'TP', 'TN', 'FN', 'FP']
    df_out = pd.DataFrame(columns=column_names)

    for i in range(5):
        data = file_reader.read_csv(pt.INTERIM_DATA_DIR, f'model{i}_{protected_col_name}_all.csv')
        for group in ["all"]:
            for measure in measures:
                # FIX: float(one-row Series) is deprecated — take the cell explicitly.
                value = float(data[data[protected_col_name] == group][measure].iloc[0])
                df_out = pd.concat([df_out,
                                    pd.DataFrame([{'Group': group,
                                                   "ML": "XGBoost" + str(i),
                                                   "Measure": measure,
                                                   "Value": value}])],
                                   ignore_index=True)

    file_writer.write_csv(df_out, pt.INTERIM_DATA_DIR, 'XGBoost_metrics_crossvalidated_all.csv')

    plt.figure()  # FIX: give this plot its own figure/axes
    global_all_bar = sns.barplot(data=df_out[df_out["Measure"].isin(["FPR", "FNR", "TPR", "TNR"])],
                                 x="Group", y="Value", ci=95, hue="Measure")
    global_all_bar.set_title('All')
    global_all_bar.get_figure().savefig(Path.joinpath(pt.REPORTS_PLOTS_DIR, f"{protected_col_name}_barplot_all.pdf"))

    # Aggregate the per-gender confusion data across folds.
    column_names = ["Group", "ML", "Measure", "Value"]
    df_out = pd.DataFrame(columns=column_names)

    for i in range(5):
        data = file_reader.read_csv(pt.INTERIM_DATA_DIR, f'model{i}_{protected_col_name}.csv')
        for group in [0.0, 1.0]:
            for measure in measures:
                value = float(data[data[protected_col_name] == group][measure].iloc[0])
                df_out = pd.concat([df_out,
                                    pd.DataFrame([{'Group': group,
                                                   "ML": "XGBoost" + str(i),
                                                   "Measure": measure,
                                                   "Value": value}])],
                                   ignore_index=True)

    file_writer.write_csv(df_out, pt.INTERIM_DATA_DIR, f'XGBoost_metrics_crossvalidated_{protected_col_name}.csv')

    plt.figure()  # FIX: previously this barplot was drawn on top of the "All" axes
    global_proc_bar = sns.barplot(data=df_out[df_out["Measure"].isin(["FPR", "FNR", "TPR", "TNR"])],
                                  x="Group", y="Value", ci=95, hue="Measure")
    global_proc_bar.set_title(f'Protected: {protected_col_name}')  # FIX: typo "Proctected"
    # FIX: the original saved global_all_bar's figure here instead of this plot.
    global_proc_bar.get_figure().savefig(Path.joinpath(pt.REPORTS_PLOTS_DIR, "barplot_proc.pdf"))

    # Collect the four rates per gender for the comparison plots below.
    column_names = ["Gender", "TPR", "FPR", "TNR", "FNR", "Model"]
    df_out = pd.DataFrame(columns=column_names)
    for i in range(5):
        data = file_reader.read_csv(pt.INTERIM_DATA_DIR, f'model{i}_{protected_col_name}.csv')
        for group in [0.0, 1.0]:
            # Renamed from `measures` to avoid shadowing the measure-name list.
            group_row = data[data[protected_col_name] == group]
            df_out = pd.concat([df_out, pd.DataFrame([{
                'Gender': "Male" if group == 1.0 else "Female",
                "TPR": float(group_row['TPR'].iloc[0]), "FPR": float(group_row['FPR'].iloc[0]),
                "TNR": float(group_row['TNR'].iloc[0]), "FNR": float(group_row['FNR'].iloc[0]),
                "Model": "XGBoost"}])], ignore_index=True)

    file_writer.write_csv(df_out, pt.INTERIM_DATA_DIR, 'XGBoost_gender.csv')

    df = df_out
    all_data_gender = df.melt(id_vars=["Gender", "Model"], var_name="Metric", value_name="Value")

    fig = plt.figure(constrained_layout=True, figsize=(15, 8))
    gs = plt.GridSpec(2, 8, figure=fig)
    gs.update(wspace=0.5)

    # One subplot slot per model (layout supports up to five models).
    ax = []
    ax.append(fig.add_subplot(gs[0, 2:4]))
    ax.append(fig.add_subplot(gs[0, 4:6]))
    ax.append(fig.add_subplot(gs[1, 1:3]))
    ax.append(fig.add_subplot(gs[1, 3:5]))
    ax.append(fig.add_subplot(gs[1, 5:7]))

    palette_custom = {"Female": "C0", "Male": "C1"}
    gender_order = ["Female", "Male"]
    list_of_models = ["XGBoost"]

    for i, v in enumerate(list_of_models):

        filter1 = all_data_gender["Model"] == v

        ax[i].set_title(v, size=15)
        ax[i].set_ylim([0, 1])
        ax[i].grid(axis='x')
        sns.barplot(data=all_data_gender[(filter1)], x="Metric", y="Value", hue="Gender", ax=ax[i],
                    errwidth=1, capsize=0.25, palette=palette_custom, hue_order=gender_order)

        # FIX: the original's two consecutive legend() calls replaced each
        # other and lost the title; one call keeps both title and location.
        ax[i].legend(title="Gender", loc="upper right")
        if i == 0:
            ax[i].set(xlabel='')
            ax[i].set_ylabel("Rate", fontsize=20)
        if i == 1:
            ax[i].set(xlabel='', ylabel='')
            ax[i].tick_params(labelleft=False)
        if i == 2:
            ax[i].set_ylabel("Rate", fontsize=20)
            ax[i].set_xlabel("", fontsize=20)
        if i == 3:
            ax[i].set(ylabel='')
            ax[i].set_xlabel("", fontsize=20)
            ax[i].tick_params(labelleft=False)
        if i == 4:
            ax[i].set_ylabel('', fontsize=20)
            ax[i].set_xlabel("", fontsize=20)

    plt.savefig(Path.joinpath(pt.REPORTS_PLOTS_DIR, "XGBoost Gender Metrics.pdf"), dpi=300, bbox_inches="tight")

    # Calculate relation between male/female rates per metric.
    frame = all_data_gender
    newFrame = pd.DataFrame([], columns=["Model", "Metric", "Abs Difference", "Relation", "Relative difference (%)"])

    for model_name in list(frame["Model"].unique()):
        for metric_name in list(frame["Metric"].unique()):
            if metric_name not in ["Mean_y_target", "Mean_y_hat_prob"]:
                female_val = frame[(frame["Model"] == model_name) & (frame["Metric"] == metric_name)
                                   & (frame["Gender"] == "Female")]["Value"].mean()
                male_val = frame[(frame["Model"] == model_name) & (frame["Metric"] == metric_name)
                                 & (frame["Gender"] == "Male")]["Value"].mean()
                relation = female_val / male_val
                newFrame = pd.concat([newFrame,
                                      pd.DataFrame([{"Model": model_name,
                                                     "Metric": metric_name,
                                                     "Relation": relation}])],
                                     ignore_index=True)

    # Plot relation of female/male rates with the 0.8–1.25 fairness band.
    plt.figure(figsize=(10, 10))

    sns.stripplot(data=newFrame, x="Metric", y="Relation", hue="Model",
                  jitter=0.2, size=10, alpha=0.8, linewidth=1, palette={"XGBoost": "C0"})

    plt.hlines(y=1, xmin=-0.5, xmax=3.5, colors='red', linestyles='--', lw=1, label='Equal relation')
    plt.hlines(y=0.8, xmin=-0.5, xmax=3.5, colors='grey', linestyles='--', lw=1, label='Relation boundary')
    plt.hlines(y=1.25, xmin=-0.5, xmax=3.5, colors='grey', linestyles='--', lw=1)

    # FIX: a second legend() call replaced the first, so loc/size are set once.
    plt.legend(loc=2, prop={'size': 15})
    plt.tick_params(axis='both', which='major', labelsize=20)
    plt.xlabel("Metric", fontsize=20, labelpad=10)
    plt.ylabel("Relation", fontsize=20, labelpad=10)
    plt.yticks(np.arange(0.6, 1.7, step=0.2))
    plt.ylim([0.5, 1.7])

    plt.savefig(Path.joinpath(pt.REPORTS_PLOTS_DIR, "XGBoost Gender Metrics Relation.pdf"), dpi=300, bbox_inches="tight")
    plt.show()
    
# Script entry point: run the full bias-evaluation pipeline.
if __name__ == "__main__":  
    main()