#!/usr/bin/env python
# Train one XGBoost classifier per case in CASES and save each model as a joblib file.
import numpy as np
import pandas as pd
4
import paths as pt
5
from tools import file_reader, file_writer, data_loader
6
from utility import metrics
7
from sklearn.metrics import accuracy_score, precision_score
8
from sklearn.metrics import recall_score, roc_auc_score
9
10
11
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
import xgboost as xgb
12
13
from pathlib import Path
import yaml
14

15
# Training cases handled by main(), in order. Each name selects that case's
# YAML config, CSV file and data loader.
CASES = [
    "Complete",
    "Compliance",
    "Fall",
]

# Version suffix of the CSV loaded for every case (e.g. "complete_emb.csv").
DATASET_VERSION = "emb"
17

18
def _load_case_data(case):
    """Load the training data for one case.

    The case name, lower-cased, selects both the config file
    ``<case>.yaml`` in ``pt.CONFIGS_DIR`` and the data file
    ``<case>_<DATASET_VERSION>.csv``. Both are handed to the matching
    ``data_loader`` class.

    Args:
        case: One of ``"Complete"``, ``"Compliance"`` or ``"Fall"``.

    Returns:
        The ``(X, y)`` pair returned by the loader's ``get_data()``.
        Presumably these are features and 0/1 labels; confirm against
        ``data_loader``.

    Raises:
        ValueError: If ``case`` has no registered loader. Previously, any
            unknown case silently fell through to the "Fall" loader.
    """
    loaders = {
        "Complete": data_loader.CompleteDataLoader,
        "Compliance": data_loader.ComplianceDataLoader,
        "Fall": data_loader.FallDataLoader,
    }
    try:
        loader_cls = loaders[case]
    except KeyError:
        raise ValueError(
            f"Unknown case {case!r}; expected one of {sorted(loaders)}"
        ) from None

    prefix = case.lower()  # "Complete" -> "complete", matching the file names
    config_path = Path(pt.CONFIGS_DIR) / f"{prefix}.yaml"
    with open(config_path, "r", encoding="utf-8") as stream:
        # safe_load: configs should never need arbitrary Python object tags.
        settings = yaml.safe_load(stream)

    file_name = f"{prefix}_{DATASET_VERSION}.csv"
    dl = loader_cls(file_name, settings).load_data()
    return dl.get_data()


def _scale_pos_weight(y):
    """Return ``#negatives / #positives`` for XGBoost's ``scale_pos_weight``.

    Args:
        y: Labels encoded as the integers 0 and 1 (any input accepted by
            ``np.bincount``).

    Returns:
        The ratio of label-0 count to label-1 count.

    Raises:
        ValueError: If ``y`` is not binary 0/1 with both classes present.
            The old ``neg, pos = np.bincount(y)`` raised only an opaque
            unpacking error here, and with no negatives it produced a
            weight of 0.
    """
    counts = np.bincount(y)
    if len(counts) != 2 or counts.min() == 0:
        raise ValueError(
            "Expected binary labels 0/1 with both classes present, "
            f"got per-label counts {counts.tolist()}"
        )
    neg, pos = counts
    return neg / pos


def main():
    """Train and save one XGBoost classifier for every case in ``CASES``.

    For each case:

    1. Load ``X, y``.
    2. Weight the positive class by the negative/positive ratio.
    3. Fit a model on the full dataset.
    4. Write it to ``pt.MODELS_DIR`` as ``<case>_xgb.joblib``.
    """
    for case in CASES:
        X, y = _load_case_data(case)

        params = {
            "n_estimators": 400,
            "learning_rate": 0.1,
            # Up-weights the positive class by the class-count ratio.
            "scale_pos_weight": _scale_pos_weight(y),
            "objective": "binary:logistic",
            # NOTE(review): deprecated since XGBoost 1.6 and ignored with a
            # warning in 2.x; kept so older installs do not label-encode.
            "use_label_encoder": False,
            "eval_metric": "logloss",
            "random_state": 0,
        }

        model = xgb.XGBClassifier(**params)
        model.fit(X, y)

        file_writer.write_joblib(model, pt.MODELS_DIR, f"{case.lower()}_xgb.joblib")
53
54
    
# Run the training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()