test_model_survival.py 1.75 KB
Newer Older
thecml's avatar
thecml committed
1
2
#!/usr/bin/env python
import paths as pt
thecml's avatar
thecml committed
3
from tools import data_loader, file_reader
thecml's avatar
thecml committed
4
5
from utility.settings import load_settings
from sksurv.ensemble import RandomSurvivalForest
thecml's avatar
thecml committed
6
from sklearn.model_selection import KFold
thecml's avatar
thecml committed
7
from sksurv.metrics import concordance_index_censored
thecml's avatar
thecml committed
8
9
10
from io import StringIO
from pathlib import Path
import shutil
thecml's avatar
thecml committed
11
12
13
14
15

def main():
    """Cross-validate a random survival forest and print per-fold C-index.

    Loads the data and alarm settings, takes the target ``y`` from the
    pickled alarm data, reads the feature table from ``alarm_emb.csv`` in
    the processed-data directory, and runs a 5-fold shuffled
    cross-validation of a ``RandomSurvivalForest``. The concordance index
    of every fold is collected and the list is printed at the end.
    """
    data_settings = load_settings(pt.CONFIGS_DIR, "data.yaml")
    target_settings = load_settings(pt.CONFIGS_DIR, "alarm.yaml")
    ats_resolution = data_settings['ats_resolution']

    dl = data_loader.AlarmDataLoader(pt.PROCESSED_DATA_DIR,
                                     "alarm_data.pkl",
                                     target_settings).load_data()
    # Only the target is used from the loader; the feature matrix is
    # replaced by the contents of alarm_emb.csv below.
    _, y = dl.get_data()

    # Read the columns "1Ats" .. "<ats_resolution>Ats" as strings.
    ats = {f"{i}Ats": str for i in range(1, ats_resolution + 1)}
    emb_path = Path(pt.PROCESSED_DATA_DIR) / "alarm_emb.csv"
    infile = StringIO()
    with open(emb_path, 'r') as fd:
        # Buffer the whole file in memory before handing it to the reader.
        shutil.copyfileobj(fd, infile)
    infile.seek(0)
    X = file_reader.read_csv(infile, converters=ats)

    model = RandomSurvivalForest(n_estimators=200,
                                 max_depth=3,
                                 random_state=0)
    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    c_index_scores = []

    # NOTE: the original author measured roughly 8 minutes for this loop.
    for train, test in kf.split(X, y):
        model.fit(X.iloc[train], y[train])
        prediction = model.predict(X.iloc[test])
        # presumably y is a structured array with "Status" (event flag) and
        # "Days_to_alarm" (time) fields -- confirm against AlarmDataLoader.
        c_index = concordance_index_censored(y[test]["Status"],
                                             y[test]["Days_to_alarm"],
                                             prediction)
        # The first element of the returned tuple is the concordance index.
        c_index_scores.append(c_index[0])

    print(c_index_scores)
47

thecml's avatar
thecml committed
48
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()