#!/usr/bin/env python
# train_xgboost_model.py
import numpy as np
3
import paths as pt
# (web blame-view residue, not program text: commit by thecml)
4
from tools import file_writer, data_loader
5
from utility.settings import load_settings
6
import xgboost as xgb
7
from pathlib import Path
8
9
from io import BytesIO
import shutil
10

11
# Prediction cases to train; one model is fitted and saved per entry.
CASES = [
    "Complete",
    "Compliance",
    "Fall",
    "Risk",
]

# Suffix selecting which processed dataset variant is read
# (files are named '<case>_<version>.csv').
DATASET_VERSION = "emb"
13

14
def _load_case_data(case):
    """Load the training features and labels for a single case.

    The settings file (``<case>.yaml``) and dataset file
    (``<case>_<DATASET_VERSION>.csv``) are derived from the lower-cased
    case name; the loader class is looked up per case.

    Args:
        case: One of "Complete", "Compliance", "Fall" or "Risk".

    Returns:
        The ``(X, y)`` pair returned by the loader's ``prepare_data()``.

    Raises:
        ValueError: If ``case`` has no registered data loader.
    """
    loader_classes = {
        "Complete": data_loader.CompleteDataLoader,
        "Compliance": data_loader.ComplianceDataLoader,
        "Fall": data_loader.FallDataLoader,
        "Risk": data_loader.RiskDataLoader,
    }
    if case not in loader_classes:
        # Fail loudly instead of silently training on another case's data
        # and saving it under the wrong name.
        raise ValueError(
            f"Unknown case {case!r}; expected one of {sorted(loader_classes)}")

    stem = case.lower()
    settings = load_settings(pt.CONFIGS_DIR, f"{stem}.yaml")
    file_name = f"{stem}_{DATASET_VERSION}.csv"
    loader = loader_classes[case](pt.PROCESSED_DATA_DIR,
                                  file_name,
                                  settings).load_data()
    return loader.prepare_data()


def _scale_pos_weight(y):
    """Return the negative/positive label-count ratio for ``y``.

    Args:
        y: Non-negative integer class labels accepted by ``np.bincount``.

    Returns:
        ``count(label == 0) / count(label == 1)``.

    Raises:
        ValueError: If ``y`` contains labels other than 0/1, or if either
            class is absent (the ratio would be undefined or degenerate).
    """
    # minlength=2 keeps the positive bin present even when no label is 1,
    # so the check below can report the problem explicitly.
    counts = np.bincount(y, minlength=2)
    if len(counts) != 2 or counts.min() == 0:
        raise ValueError(
            "Expected binary labels with both classes present, "
            f"got class counts {counts.tolist()}")
    neg, pos = counts
    return neg / pos


def main():
    """Train one XGBoost classifier per entry in ``CASES`` and save it.

    For each case the data is loaded, a classifier is fitted on the full
    ``(X, y)`` set with ``scale_pos_weight`` set to the class-imbalance
    ratio, and the fitted model is written to
    ``pt.MODELS_DIR/<case>_xgb.joblib``.

    Raises:
        ValueError: If a case is unknown or its labels are not binary with
            both classes present.
    """
    for case in CASES:
        X, y = _load_case_data(case)

        params = {"n_estimators": 400,
                  "learning_rate": 0.1,
                  "scale_pos_weight": _scale_pos_weight(y),
                  "objective": "binary:logistic",
                  # NOTE(review): newer xgboost releases deprecate/ignore
                  # this flag - confirm against the pinned version.
                  "use_label_encoder": False,
                  "eval_metric": "logloss",
                  "random_state": 0}

        model = xgb.XGBClassifier(**params)
        model.fit(X, y)

        # Serialise into memory first: the target file is only opened (and
        # truncated) once the dump has succeeded, so a failing write_joblib
        # cannot clobber a previously saved model.
        buffer = BytesIO()
        file_writer.write_joblib(model, buffer)

        # Path(...) accepts MODELS_DIR whether it is a str or a Path.
        model_path = Path(pt.MODELS_DIR) / f'{case.lower()}_xgb.joblib'
        model_path.write_bytes(buffer.getbuffer())
66
# Run the training only when executed as a script, not on import.
if __name__ == '__main__':
    main()