tune_alarm_boost_wb.py 4.6 KB
Newer Older
thecml's avatar
thecml committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from utility.settings import load_settings
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import (concordance_index_censored,
                            concordance_index_ipcw,
                            integrated_brier_score)
from sklearn.model_selection import KFold
from tools import data_loader, preprocessor
import paths as pt
import numpy as np
import pandas as pd
import os

# Silence wandb's console output. The variable is set before `import wandb`,
# presumably so the setting is already in place whenever wandb reads its
# environment.
os.environ["WANDB_SILENT"] = "true"
import wandb

# W&B sweep definition: random search over GradientBoostingSurvivalAnalysis
# hyper-parameters. Runs are ranked by the "c_harrell" metric that
# train_model() logs.
sweep_config = {
    "method": "random",  # try grid or random
    "metric": {
        "name": "c_harrell",
        "goal": "maximize",
    },
    "parameters": {
        "n_estimators": {
            "values": [50, 100, 200, 400, 600, 800, 1000],
        },
        "learning_rate": {
            "values": [0.1, 0.5, 1.0],
        },
        "max_depth": {
            # Integer depths spread between 1 and 18; set() guards against
            # duplicates introduced by the int truncation.
            "values": sorted({int(x) for x in np.linspace(1, 18, 15, endpoint=True)}),
        },
        "loss": {
            "values": ["coxph"],
        },
        "min_samples_split": {
            # Truncating linspace(2, 10, 10) produces 2 twice; deduplicate so
            # random sampling is not biased towards that value.
            "values": sorted({int(x) for x in np.linspace(2, 10, 10, endpoint=True)}),
        },
        "max_features": {
            # "auto" was dropped: for this (regressor-based) estimator it was
            # equivalent to None, and it is deprecated since scikit-learn 1.1
            # and no longer accepted by later releases.
            "values": [None, "sqrt", "log2"],
        },
        "dropout_rate": {
            # Rounded to strip float noise from linspace (e.g. 0.30000000000000004).
            "values": [round(float(x), 2) for x in np.linspace(0.0, 0.9, 10, endpoint=True)],
        },
        "subsample": {
            "values": [round(float(x), 2) for x in np.linspace(0.1, 1.0, 10, endpoint=True)],
        },
    },
}

def main():
    """Register ``sweep_config`` as a W&B sweep and run five trials of it.

    Each trial is executed by ``train_model`` under the W&B agent.
    """
    project_name = "air-alarm-boost"
    new_sweep = wandb.sweep(sweep_config, project=project_name)
    trial_budget = 5
    wandb.agent(new_sweep, train_model, count=trial_budget)

def _load_encoded_data():
    """Load the alarm dataset and one-hot encode its ``<i>Ats`` columns.

    Returns:
        Tuple ``(X, y)``. ``X`` is a DataFrame in which the columns
        ``1Ats`` .. ``<ats_resolution>Ats`` are replaced by the output of
        ``preprocessor.one_hot_encode``. ``y`` is the target exactly as
        returned by ``AlarmDataLoader.get_data``; downstream code indexes it
        with ``["Status"]``/``["Days_to_alarm"]`` and integer arrays
        (presumably a numpy structured array — verify against the loader).
    """
    data_settings = load_settings(pt.CONFIGS_DIR, "data.yaml")
    target_settings = load_settings(pt.CONFIGS_DIR, "alarm.yaml")
    dl = data_loader.AlarmDataLoader(pt.PROCESSED_DATA_DIR,
                                     "alarm_data.pkl",
                                     target_settings).load_data()
    X, y = dl.get_data()

    # Replace the Ats columns with their one-hot encoding.
    ats_resolution = data_settings['ats_resolution']
    ats_cols = [f"{i}Ats" for i in range(1, ats_resolution + 1)]
    X_enc = preprocessor.one_hot_encode(X, ats_cols)
    X = pd.concat([X.drop(ats_cols, axis=1), X_enc], axis=1)
    return X, y


def _cross_validate(model, X, y, n_splits=5, seed=0):
    """Score ``model`` with shuffled K-fold CV.

    Args:
        model: An unfitted-or-refittable sksurv estimator. It is refit on every
            fold.
        X: Feature DataFrame (indexed positionally with ``.iloc``).
        y: Survival target with ``"Status"`` and ``"Days_to_alarm"`` fields.
        n_splits: Number of CV folds.
        seed: ``random_state`` for the fold shuffling.

    Returns:
        Tuple ``(c_harrell, c_uno, integrated_brier)``. Each value is the mean
        over folds of Harrell's C, Uno's IPCW C, and the integrated Brier score.
    """
    # One shared evaluation grid (10th-90th percentile of all alarm times) for
    # every fold. It does not depend on the fold, so compute it once.
    # NOTE(review): integrated_brier_score requires these times to lie within
    # each test fold's follow-up range; this assumes every fold spans it.
    lower, upper = np.percentile(y["Days_to_alarm"], [10, 90])
    alarm_times = np.arange(lower, upper + 1)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    c_harrells, c_unos, brier_scores = [], [], []
    for train, test in kf.split(X, y):
        y_train, y_test = y[train], y[test]
        X_test = X.iloc[test]

        model.fit(X.iloc[train], y_train)
        prediction = model.predict(X_test)

        c_harrell = concordance_index_censored(y_test["Status"],
                                               y_test["Days_to_alarm"],
                                               prediction)
        # The IPCW weights are estimated from the training fold.
        c_uno = concordance_index_ipcw(y_train, y_test, prediction)
        # np.vstack replaces the NumPy-2.0-deprecated alias np.row_stack.
        surv_prob = np.vstack([fn(alarm_times)
                               for fn in model.predict_survival_function(X_test)])
        brier_score = integrated_brier_score(y_train, y_test,
                                             surv_prob, alarm_times)

        c_harrells.append(c_harrell[0])
        c_unos.append(c_uno[0])
        brier_scores.append(brier_score)

    return np.mean(c_harrells), np.mean(c_unos), np.mean(brier_scores)


def train_model():
    """Run one W&B sweep trial.

    The trial builds a GradientBoostingSurvivalAnalysis from ``wandb.config``,
    evaluates it with 5-fold CV, and logs ``c_harrell``, ``c_uno`` and
    ``brier_score``.
    """
    config_defaults = {
        'n_estimators': 100,
        'learning_rate': 0.1,
        'max_depth': 3,
        'loss': 'coxph',
        'min_samples_split': 2,
        'max_features': None,
        'dropout_rate': 0.0,
        'subsample': 1.0,
        'seed': 0,
        'test_size': 0.25,
    }

    # Initialize a new wandb run; sweep-chosen values override the defaults.
    wandb.init(config=config_defaults)
    config = wandb.config

    X, y = _load_encoded_data()

    model = GradientBoostingSurvivalAnalysis(
        n_estimators=config.n_estimators,
        learning_rate=config.learning_rate,
        max_depth=config.max_depth,
        loss=config.loss,
        min_samples_split=config.min_samples_split,
        max_features=config.max_features,
        dropout_rate=config.dropout_rate,
        # subsample is part of the sweep; it was previously never passed on.
        subsample=config.subsample,
        random_state=config.seed)

    c_harrell, c_uno, brier_score = _cross_validate(model, X, y,
                                                    n_splits=5,
                                                    seed=config.seed)

    # A single log call keeps all three metrics on the same step.
    wandb.log({"c_harrell": c_harrell,
               "c_uno": c_uno,
               "brier_score": brier_score})

# Launch the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()