make_dataset_full.py 3.41 KB
Newer Older
1
#!/usr/bin/env python
2
import paths as pt
3
from tools import file_reader, file_writer, labeler
4
5
from tools import preprocessor
import pandas as pd
6
7
import yaml
from pathlib import Path
8

9
10
def _load_settings(label_name):
    """Return the parsed ``<label_name lowercased>_emb.yaml`` config from ``pt.CONFIGS_DIR``."""
    # Path(...) / ... works whether CONFIGS_DIR is a str or a Path, unlike the
    # unbound ``Path.joinpath(CONFIGS_DIR, ...)`` call.
    config_path = Path(pt.CONFIGS_DIR) / f'{label_name.lower()}_emb.yaml'
    with open(config_path, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)


def _encode_label(df, label_name, settings):
    """Return *df* with the target column for *label_name* added by the matching labeler.

    Any label other than 'Complete', 'Compliance' or 'Alarm' is treated as 'Fall'.
    """
    if label_name == 'Complete':
        return labeler.make_complete_label(df, settings)
    if label_name == 'Compliance':
        return labeler.make_compliance_label(df, settings)
    if label_name == 'Alarm':
        df = labeler.make_alarm_label(df, settings)
        # NOTE(review): ATS code '222718' is rewritten to '0' for the Alarm
        # target only; the reason is not visible in this file - confirm.
        # Assumes every 'Ats' value is a str (the lambda calls str.replace).
        df['Ats'] = df['Ats'].apply(lambda x: x.replace('222718', '0'))
        return df
    return labeler.make_fall_label(df, settings)


def _replace_with_real_names(df, include_ex):
    """Replace categorical codes in *df* with names from the reference tables.

    Always applies ``ats.csv``; also applies ``ex.csv`` when *include_ex* is true.
    """
    ats = file_reader.read_csv(pt.REFERENCES_DIR, 'ats.csv',
                               converters={'ats_id': str})
    df = preprocessor.replace_cat_values(df, ats)
    if include_ex:
        ex = file_reader.read_csv(pt.REFERENCES_DIR, 'ex.csv',
                                  converters={'ex_id': str})
        df = preprocessor.replace_cat_values(df, ex)
    return df


def main(ats_resolution: int = None):
    """Build one processed CSV per target label from the interim data.

    For each of 'Complete', 'Compliance', 'Alarm' and 'Fall' this loads the
    label's ``<label>_emb.yaml`` config, attaches the ``Cluster`` column from
    ``cl.csv`` to a copy of ``screenings.csv``, encodes the target, splits the
    ``Ats`` column (and ``Ex`` for 'Fall'), one-hot encodes ``Gender``,
    orders the columns as features / split categorical columns / target,
    optionally swaps codes for real names, and writes ``<label>.csv`` to
    ``pt.PROCESSED_DATA_DIR``.

    Args:
        ats_resolution: Resolution passed to ``split_cat_columns`` for ``Ats``.
            When ``None`` (the default), each label uses the
            ``ats_resolution`` from its own config file.
    """
    clusters = file_reader.read_csv(pt.INTERIM_DATA_DIR, 'cl.csv',
                                    converters={'CitizenId': str, 'Cluster': int})
    screenings = file_reader.read_csv(pt.INTERIM_DATA_DIR, 'screenings.csv',
                                      converters={'CitizenId': str})

    for label_name in ['Complete', 'Compliance', 'Alarm', 'Fall']:
        settings = _load_settings(label_name)
        is_fall = label_name == 'Fall'

        # Resolve into a per-iteration local: reassigning the argument itself
        # would make every label after the first reuse the first config's value.
        resolution = (ats_resolution if ats_resolution is not None
                      else settings['ats_resolution'])
        ex_resolution = settings['ex_resolution'] if is_fall else None
        features = settings['features']

        df = screenings.copy()
        # NOTE(review): column assignment aligns on the index, so this assumes
        # cl.csv rows are in the same order as screenings.csv - confirm.
        df['Cluster'] = clusters['Cluster']

        df = _encode_label(df, label_name, settings)

        # Split categorical columns by their resolution.
        df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats',
                                            resolution=resolution)
        if is_fall:
            df = preprocessor.split_cat_columns(df, col_to_split='Ex', tag='Ex',
                                                resolution=ex_resolution)

        # One-hot-encode gender variable.
        object_cols = ['Gender']
        df_enc = preprocessor.one_hot_encode(df, object_cols)
        df = pd.concat([df.drop(object_cols, axis=1), df_enc], axis=1)
        # NOTE(review): assumes both Female and Male occur in the data so that
        # both one-hot columns exist; otherwise this raises KeyError.
        df['Gender_Female'] = df['Gender_Female'].astype(int)
        df['Gender_Male'] = df['Gender_Male'].astype(int)

        # Order columns: features, split categorical columns, then the target.
        # NOTE(review): DataFrame.filter(regex=...) uses re.search, so any
        # column whose name merely contains 'Ats'/'Ex' is picked up - confirm
        # no feature column name collides, or it will appear twice.
        cat_regex = 'Ats|Ex' if is_fall else 'Ats'
        cat_cols = df.filter(regex=cat_regex, axis=1)
        df = pd.concat([df[features], cat_cols, df[[label_name]]], axis=1)

        if settings['use_real_ats_names']:
            df = _replace_with_real_names(df, include_ex=is_fall)

        file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}.csv')


if __name__ == "__main__":
    # Script entry point: run without an ATS resolution override, so main
    # falls back to the value in the config files.
    main(ats_resolution=None)