test_model_balance.py 3.4 KB
Newer Older
1
#!/usr/bin/env python
2
from tools import data_loader
3
4
5
6
7
8
9
from utility.settings import load_settings
import csv
import paths as pt
from pathlib import Path
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
from utility.metrics import compute_mean, compute_std
10
from sklearn.ensemble import RandomForestClassifier
11

12
13
14
15
def make_model(class_weight=None):
    """Build the random forest classifier evaluated by this script.

    Args:
        class_weight: Passed through unchanged as the ``class_weight``
            argument of ``RandomForestClassifier``. Defaults to ``None``.

    Returns:
        An unfitted ``RandomForestClassifier`` configured with 800 trees
        and ``random_state=0``.
    """
    forest_params = {
        "n_estimators": 800,
        "class_weight": class_weight,
        "random_state": 0,
    }
    return RandomForestClassifier(**forest_params)
16
17
18
19
20
21
22
23
24
25
26
27
28

# Column names of every report file: classifier and version labels, then one
# mean/std pair per entry of _METRICS, in the same order (the 'pr_auc'
# columns hold the 'average_precision' scores).
_REPORT_HEADER = ['clf', 'version', 'accuracy_mean', 'accuracy_std',
                  'precision_mean', 'precision_std', 'recall_mean',
                  'recall_std', 'roc_auc_mean', 'roc_auc_std',
                  'pr_auc_mean', 'pr_auc_std', 'f1_mean', 'f1_std']

# Scorer names handed to cross_validate; results are read back under the
# key 'test_<name>'.
_METRICS = ['accuracy', 'precision', 'recall', 'roc_auc',
            'average_precision', 'f1']

# Version label written to the report -> class_weight given to make_model.
_CLASS_WEIGHTS = {'NoCW': None, 'CW': 'balanced'}


def _load_case_data(case):
    """Load the settings and embedded dataset of one case.

    Args:
        case: One of "Complete", "Compliance", "Fall" or "Risk".

    Returns:
        Whatever the case's data loader returns from ``prepare_data()``;
        ``main`` unpacks it as ``(X, y)``.

    Raises:
        KeyError: If ``case`` is not one of the four known cases.
    """
    sources = {
        "Complete": (data_loader.CompleteDataLoader,
                     "complete.yaml", "complete_emb.csv"),
        "Compliance": (data_loader.ComplianceDataLoader,
                       "compliance.yaml", "compliance_emb.csv"),
        "Fall": (data_loader.FallDataLoader,
                 "fall.yaml", "fall_emb.csv"),
        "Risk": (data_loader.RiskDataLoader,
                 "risk.yaml", "risk_emb.csv"),
    }
    loader_cls, config_file, data_file = sources[case]
    settings = load_settings(pt.CONFIGS_DIR, config_file)
    dl = loader_cls(pt.PROCESSED_DATA_DIR, data_file, settings).load_data()
    # NOTE(review): the result of get_data() was always overwritten by
    # prepare_data() and is therefore discarded here. The call itself is
    # kept in case it has side effects -- confirm and remove if it is a
    # plain getter.
    dl.get_data()
    return dl.prepare_data()


def _write_row(report_path, row, mode):
    """Write a single CSV row to ``report_path`` opened with ``mode``."""
    with open(report_path, mode, encoding='UTF8', newline='') as f:
        csv.writer(f).writerow(row)


def main():
    """Cross-validate the random forest with and without class weighting.

    For each case a report ``"<case> model balance.csv"`` is created in
    ``pt.REPORTS_DIR``. It holds a header row followed by one row per
    version ('NoCW' and 'CW') containing the mean and standard deviation
    of every metric in ``_METRICS`` over a 5-fold stratified
    cross-validation.
    """
    for case in ["Complete", "Compliance", "Fall", "Risk"]:
        report_path = Path(pt.REPORTS_DIR) / f"{case} model balance.csv"
        # Mode 'w' truncates any report left over from a previous run; the
        # rows are appended one by one so finished versions are on disk
        # even if a later cross-validation fails.
        _write_row(report_path, _REPORT_HEADER, 'w')

        X, y = _load_case_data(case)

        for version, class_weight in _CLASS_WEIGHTS.items():
            model = make_model(class_weight=class_weight)
            # Same splitter configuration for both versions, so the only
            # difference between the two rows is the class weighting.
            kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
            results = cross_validate(model, X, y, cv=kfold, scoring=_METRICS)

            row = ["RF", version]
            for metric in _METRICS:
                scores = results[f'test_{metric}']
                row.extend((compute_mean(scores), compute_std(scores)))
            _write_row(report_path, row, 'a')

# Run the evaluation only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    main()