# tune_alarm_rsf_wb.py — committed by thecml.
# Tunes RandomSurvivalForest hyperparameters on the alarm data with a
# Weights & Biases random-search sweep.
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import (concordance_index_censored,
                            concordance_index_ipcw,
                            integrated_brier_score)

import paths as pt
from tools import data_loader, preprocessor
from utility.settings import load_settings

# WANDB_SILENT=true asks wandb to suppress its console log output.
# The variable is set before `import wandb` (hence this mid-file import),
# presumably so the setting is already in place when the library loads —
# confirm against the wandb environment-variable docs before reordering.
os.environ["WANDB_SILENT"] = "true"
import wandb

# Search space for the W&B sweep. Each trial draws one value per
# hyperparameter at random, and trials are ranked by the "c_harrell" metric
# that train_model logs (higher is better).
#
# Values are converted to built-in int/float, presumably so the sweep config
# serialises cleanly. Float min_samples_split / min_samples_leaf values are
# read by the forest as fractions of the training samples.
# "auto" is intentionally not offered for max_features: scikit-learn
# deprecated it in 1.1 and removed it in 1.3, so newer scikit-survival
# releases reject it and such trials would crash.
sweep_config = {
    "method": "random",  # "grid" is the exhaustive alternative
    "metric": {
        "name": "c_harrell",
        "goal": "maximize",
    },
    "parameters": {
        "n_estimators": {
            "values": [50, 100, 200, 400, 600, 800, 1000],
        },
        "max_depth": {
            "values": list(range(1, 33)),
        },
        "min_samples_split": {
            "values": [float(x) for x in np.linspace(0.1, 0.9, 10)],
        },
        "min_samples_leaf": {
            "values": [float(x) for x in np.linspace(0.1, 0.5, 5)],
        },
        "max_features": {
            "values": [None, "sqrt", "log2"],
        },
    },
}

def main():
    """Create the W&B sweep and run five ``train_model`` trials through an agent.

    The sweep is registered in the ``air-alarm-rsf`` project using the
    module-level ``sweep_config``.
    """
    new_sweep = wandb.sweep(sweep_config, project="air-alarm-rsf")
    wandb.agent(new_sweep, function=train_model, count=5)

def train_model():
    """Run one sweep trial: cross-validate a Random Survival Forest and log metrics.

    Hyperparameters come from ``wandb.config``: the sweep agent fills them in,
    and ``config_defaults`` is used for anything it does not supply (e.g. when
    run outside a sweep). The model is evaluated with shuffled 5-fold
    cross-validation. The fold means of Harrell's C-index, Uno's C-index and
    the integrated Brier score are logged to W&B as ``c_harrell``, ``c_uno``
    and ``brier_score``.
    """
    # Fallback hyperparameters. They must be scalars: they are passed straight
    # to RandomSurvivalForest, which cannot accept list values such as [100].
    config_defaults = {
        "n_estimators": 100,
        "max_depth": None,
        "min_samples_split": 2,
        "min_samples_leaf": 1,
        "max_features": None,
        "seed": 0,
        "test_size": 0.25,  # recorded in the run config; not read below
    }

    # Initialize a new wandb run; wandb.config holds the effective hyperparameters.
    wandb.init(config=config_defaults)
    config = wandb.config

    X, y = _load_alarm_data()

    model = RandomSurvivalForest(n_estimators=config.n_estimators,
                                 max_depth=config.max_depth,
                                 min_samples_split=config.min_samples_split,
                                 min_samples_leaf=config.min_samples_leaf,
                                 max_features=config.max_features,
                                 random_state=config.seed)

    # Time grid for the integrated Brier score: unit-spaced points covering
    # the 10th-90th percentile range of all observed times. It does not depend
    # on the fold, so it is built once; _evaluate_fold trims it per fold.
    lower, upper = np.percentile(y["Days_to_alarm"], [10, 90])
    alarm_times = np.arange(lower, upper + 1)

    kf = KFold(n_splits=5, shuffle=True, random_state=config.seed)
    fold_scores = [_evaluate_fold(model, X, y, train_idx, test_idx, alarm_times)
                   for train_idx, test_idx in kf.split(X, y)]
    c_harrells, c_unos, brier_scores = zip(*fold_scores)

    # A single log call keeps all three metrics on the same W&B step.
    wandb.log({
        "c_harrell": np.mean(c_harrells),
        "c_uno": np.mean(c_unos),
        "brier_score": np.mean(brier_scores),
    })


def _load_alarm_data():
    """Load the processed alarm data and one-hot encode its ``<i>Ats`` columns.

    Returns:
        ``(X, y)``: ``X`` is a pandas DataFrame with the ``1Ats`` ..
        ``<ats_resolution>Ats`` columns replaced by their one-hot encoding.
        ``y`` is the target exactly as returned by the data loader; callers
        read its ``Status`` and ``Days_to_alarm`` fields.
    """
    data_settings = load_settings(pt.CONFIGS_DIR, "data.yaml")
    target_settings = load_settings(pt.CONFIGS_DIR, "alarm.yaml")
    dl = data_loader.AlarmDataLoader(pt.PROCESSED_DATA_DIR,
                                     "alarm_data.pkl",
                                     target_settings).load_data()
    X, y = dl.get_data()

    ats_resolution = data_settings['ats_resolution']
    ats_cols = [f"{i}Ats" for i in range(1, ats_resolution + 1)]
    X_enc = preprocessor.one_hot_encode(X, ats_cols)
    X = pd.concat([X.drop(ats_cols, axis=1), X_enc], axis=1)
    return X, y


def _evaluate_fold(model, X, y, train_idx, test_idx, alarm_times):
    """Fit *model* on one CV fold's training rows and score it on the held-out rows.

    Returns:
        ``(harrell_c, uno_c, integrated_brier)`` for this fold.
    """
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    model.fit(X_train, y_train)
    risk = model.predict(X_test)

    c_harrell = concordance_index_censored(y_test["Status"],
                                           y_test["Days_to_alarm"],
                                           risk)[0]
    c_uno = concordance_index_ipcw(y_train, y_test, risk)[0]

    # scikit-survival rejects Brier evaluation times outside the test fold's
    # follow-up range [min, max), so keep only the grid points inside it.
    test_times = y_test["Days_to_alarm"]
    in_range = (alarm_times >= test_times.min()) & (alarm_times < test_times.max())
    fold_times = alarm_times[in_range]

    # np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0.
    surv_prob = np.vstack([fn(fold_times)
                           for fn in model.predict_survival_function(X_test)])
    brier = integrated_brier_score(y_train, y_test, surv_prob, fold_times)
    return c_harrell, c_uno, brier

# Start the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()