#!/usr/bin/env python
import shutil
from io import BytesIO
from pathlib import Path

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import recall_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import paths as pt
from tools import file_reader, file_writer, data_loader
from utility.settings import load_settings

# Outcome cases trained by main(); one model is produced per entry.
CASES = [
    "Complete",
    "Compliance",
    "Fall",
]

# Version tag appended to the processed CSV file names (e.g. complete_emb.csv).
DATASET_VERSION = 'emb'

def _load_case_data(case):
    """Load the feature matrix and labels for one training case.

    The settings file ``<case>.yaml`` is read from ``pt.CONFIGS_DIR`` and the
    dataset ``<case>_<DATASET_VERSION>.csv`` from ``pt.PROCESSED_DATA_DIR``,
    where ``<case>`` is the lower-cased case name.

    Args:
        case: One of "Complete", "Compliance" or "Fall".

    Returns:
        The ``(X, y)`` pair returned by the case loader's ``prepare_data()``.

    Raises:
        ValueError: If ``case`` has no registered data loader.
    """
    loader_classes = {
        "Complete": data_loader.CompleteDataLoader,
        "Compliance": data_loader.ComplianceDataLoader,
        "Fall": data_loader.FallDataLoader,
    }
    try:
        loader_cls = loader_classes[case]
    except KeyError:
        # Fail loudly instead of silently training on another case's data.
        raise ValueError(
            f"Unknown case {case!r}; expected one of {sorted(loader_classes)}"
        ) from None

    prefix = case.lower()
    settings = load_settings(pt.CONFIGS_DIR, f"{prefix}.yaml")
    file_name = f'{prefix}_{DATASET_VERSION}.csv'
    loader = loader_cls(pt.PROCESSED_DATA_DIR, file_name, settings).load_data()
    return loader.prepare_data()


def _scale_pos_weight(y):
    """Return the negative-to-positive count ratio for binary 0/1 labels.

    The result is passed to XGBoost as ``scale_pos_weight``.

    Raises:
        ValueError: If ``y`` contains labels other than 0 and 1, or contains
            no positive (1) labels, in which case the ratio is undefined.
    """
    # minlength=2 keeps the two-way unpacking valid when y has no positives.
    counts = np.bincount(y, minlength=2)
    if counts.size != 2:
        raise ValueError(
            f"Expected binary 0/1 labels, found {counts.size} classes"
        )
    neg, pos = counts
    if pos == 0:
        raise ValueError("No positive labels in y; scale_pos_weight is undefined")
    return neg / pos


def _save_model(model, case):
    """Write ``model`` to ``pt.MODELS_DIR/<case>_xgb.joblib``.

    The model is serialized in memory with ``file_writer.write_joblib`` first;
    the target file is only opened afterwards, so a serialization error does
    not truncate an existing model file.
    """
    buffer = BytesIO()
    file_writer.write_joblib(model, buffer)
    # Path(...) accepts MODELS_DIR whether it is a str or a Path.
    model_path = Path(pt.MODELS_DIR) / f'{case.lower()}_xgb.joblib'
    model_path.write_bytes(buffer.getvalue())


def main():
    """Train and save one XGBoost classifier per entry in ``CASES``.

    For each case the data is loaded and ``scale_pos_weight`` is set from the
    label balance. A 400-tree ``XGBClassifier`` is then fit on the full
    ``X, y``, and the model is written to ``pt.MODELS_DIR`` as
    ``<case>_xgb.joblib``.
    """
    for case in CASES:
        X, y = _load_case_data(case)

        params = {
            "n_estimators": 400,
            "learning_rate": 0.1,
            "scale_pos_weight": _scale_pos_weight(y),
            "objective": "binary:logistic",
            # NOTE(review): deprecated in newer xgboost releases; kept for
            # older installs — confirm against the pinned xgboost version.
            "use_label_encoder": False,
            "eval_metric": "logloss",
            "random_state": 0,
        }

        model = xgb.XGBClassifier(**params)
        # Trained on all loaded data; no hold-out split is made here.
        model.fit(X, y)

        _save_model(model, case)
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()