Commit ec99f3f1 authored by Christian Marius Lillelund's avatar Christian Marius Lillelund
Browse files

reworked the fall case

parent b6406ffa
Pipeline #87503 failed with stage
in 2 minutes and 53 seconds
......@@ -2,16 +2,16 @@
# Dataset Stuff -------------------------------------------------
#
target_name: "Risk"
model_path: models/risk/embeddings
risk_period_months: 6
target_name: "Alarm"
model_path: models/alarm/embeddings
threshold_weeks: 8
threshold_training: 10
# Embedding Hyperparams --------------------------------------
train_ratio: 0.8
batch_size: 32
num_epochs_ats: 10
num_epochs_ex: 5
num_epochs: 10
verbose: False
network_layers: [128]
optimizer: "Adam"
......@@ -19,30 +19,18 @@ optimizer: "Adam"
# Settings for data loader -------------------------------------------------
#
features_to_normalize: ['BirthYear', 'Cluster', 'LoanPeriod', 'NumberSplit',
'NumberScreening', 'NumberWeeks', 'MeanEvaluation',
'NumberFalls', 'NumberTraining', 'NumberTrainingWeek',
'TimeBetweenTraining', 'NumberWeeksNoTraining',
'Needs', 'Physics', 'NumberAts', 'NumberEx']
features_to_scale: ['Gender_Male', 'Gender_Female', 'BirthYear', 'Cluster',
'LoanPeriod', 'NumberSplit', 'NumberScreening',
'NumberWeeks', 'MeanEvaluation',
'NumberFalls', 'NumberTraining', 'NumberTrainingWeek',
'TimeBetweenTraining', 'NumberWeeksNoTraining',
'Needs', 'Physics', 'NumberAts', 'NumberEx']
features_to_normalize: ['BirthYear', 'Cluster', 'LoanPeriod', 'NumberAts']
features_to_scale: ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod', 'NumberAts']
# Settings for data script -------------------------------------------------
#
features: ['Gender_Male', 'Gender_Female', 'BirthYear', 'Cluster',
'LoanPeriod', 'NumberSplit', 'NumberScreening', 'NumberWeeks',
'MeanEvaluation', 'NumberFalls', 'NumberTraining',
'NumberTrainingWeek', 'TimeBetweenTraining',
'NumberWeeksNoTraining', 'Needs', 'Physics']
features: ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod']
# Settings for dataset -------------------------------------------------
#
use_real_ats_names: False
ats_resolution: 10
ex_resolution: 9
\ No newline at end of file
ats_resolution: 50
\ No newline at end of file
......@@ -33,4 +33,4 @@ features: ['Gender_Male', 'Gender_Female', 'BirthYear',
#
use_real_ats_names: False
ats_resolution: 10
\ No newline at end of file
ats_resolution: 50
\ No newline at end of file
......@@ -33,4 +33,4 @@ features: ['Gender_Male', 'Gender_Female', 'BirthYear',
#
use_real_ats_names: False
ats_resolution: 10
\ No newline at end of file
ats_resolution: 50
\ No newline at end of file
......@@ -4,14 +4,14 @@
target_name: "Fall"
model_path: models/fall/embeddings
threshold_weeks: 8
threshold_training: 10
fall_period_months: 6
# Embedding Hyperparams --------------------------------------
train_ratio: 0.8
batch_size: 32
num_epochs: 10
num_epochs_ats: 10
num_epochs_ex: 5
verbose: False
network_layers: [128]
optimizer: "Adam"
......@@ -19,18 +19,30 @@ optimizer: "Adam"
# Settings for data loader -------------------------------------------------
#
features_to_normalize: ['BirthYear', 'Cluster', 'LoanPeriod', 'NumberAts']
features_to_scale: ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod', 'NumberAts']
features_to_normalize: ['BirthYear', 'Cluster', 'LoanPeriod', 'NumberSplit',
'NumberScreening', 'NumberWeeks', 'MeanEvaluation',
'NumberTraining', 'NumberTrainingWeek',
'TimeBetweenTraining', 'NumberWeeksNoTraining',
'Needs', 'Physics', 'NumberAts', 'NumberEx']
features_to_scale: ['Gender_Male', 'Gender_Female', 'BirthYear', 'Cluster',
'LoanPeriod', 'NumberSplit', 'NumberScreening',
'NumberWeeks', 'MeanEvaluation',
'NumberTraining', 'NumberTrainingWeek',
'TimeBetweenTraining', 'NumberWeeksNoTraining',
'Needs', 'Physics', 'NumberAts', 'NumberEx']
# Settings for data script -------------------------------------------------
#
features: ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod']
features: ['Gender_Male', 'Gender_Female', 'BirthYear', 'Cluster',
'LoanPeriod', 'NumberSplit', 'NumberScreening', 'NumberWeeks',
'MeanEvaluation', 'NumberTraining', 'NumberTrainingWeek',
'TimeBetweenTraining', 'NumberWeeksNoTraining', 'Needs', 'Physics']
# Settings for dataset -------------------------------------------------
#
use_real_ats_names: False
ats_resolution: 10
\ No newline at end of file
ats_resolution: 50
ex_resolution: 9
risk_period_months: 6
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -58,11 +58,11 @@ def main():
X, y = dl.get_data()
elif case == "Fall":
settings = load_settings("fall_emb.yaml")
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
dl = data_loader.AlarmDataLoader("fall_emb.csv", settings).load_data()
X, y = dl.get_data()
else:
settings = load_settings("risk_emb.yaml")
dl = data_loader.RiskDataLoader("risk_emb.csv", settings).load_data()
dl = data_loader.FallDataLoader("risk_emb.csv", settings).load_data()
X, y = dl.get_data()
X, y = dl.prepare_data()
......
......@@ -34,11 +34,11 @@ def load_data_embedded(case, settings):
elif case == "Compliance":
dl = data_loader.ComplianceDataLoader("compliance_emb.csv", settings).load_data()
X, y = dl.get_data()
elif case == "Fall":
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
elif case == "Alarm":
dl = data_loader.AlarmDataLoader("alarm_emb.csv", settings).load_data()
X, y = dl.get_data()
else:
dl = data_loader.RiskDataLoader("risk_emb.csv", settings).load_data()
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
X, y = dl.get_data()
return X, y
......@@ -49,26 +49,26 @@ def load_data_count(case, settings):
elif case == "Compliance":
dl = data_loader.ComplianceDataLoader("compliance_count.csv", settings).load_data()
X, y = dl.get_data()
elif case == "Fall":
dl = data_loader.FallDataLoader("fall_count.csv", settings).load_data()
elif case == "Alarm":
dl = data_loader.AlarmDataLoader("alarm_count.csv", settings).load_data()
X, y = dl.get_data()
else:
dl = data_loader.RiskDataLoader("risk_count.csv", settings).load_data()
dl = data_loader.FallDataLoader("fall_count.csv", settings).load_data()
X, y = dl.get_data()
return X, y
def load_data_onehot(case, settings):
def load_data_ohe(case, settings):
if case == "Complete":
dl = data_loader.CompleteDataLoader("complete_ohe.csv", settings).load_data()
X, y = dl.get_data()
elif case == "Compliance":
dl = data_loader.ComplianceDataLoader("compliance_ohe.csv", settings).load_data()
X, y = dl.get_data()
elif case == "Fall":
dl = data_loader.FallDataLoader("fall_ohe.csv", settings).load_data()
elif case == "Alarm":
dl = data_loader.AlarmDataLoader("alarm_ohe.csv", settings).load_data()
X, y = dl.get_data()
else:
dl = data_loader.RiskDataLoader("risk_ohe.csv", settings).load_data()
dl = data_loader.FallDataLoader("fall_ohe.csv", settings).load_data()
X, y = dl.get_data()
return X, y
......@@ -77,7 +77,8 @@ def main():
clf_names = ['KNN', 'SVM', 'LR', 'XGB', 'RF', 'MLP']
num_clfs = len(clf_names)
metrics = ['accuracy', 'precision', 'recall', 'roc_auc', 'average_precision', 'f1']
cases = ["Complete", "Compliance", "Fall", "Risk"]
#cases = ["Complete", "Compliance", "Fall", "Risk"]
cases = ['Alarm']
for case in cases:
settings = load_settings(f'{case.lower()}_emb.yaml')
output_filename = f"{case} model baseline.csv"
......@@ -100,7 +101,7 @@ def main():
elif version == "Counts":
X, y = load_data_count(case, settings)
else:
X, y = load_data_onehot(case, settings)
X, y = load_data_ohe(case, settings)
X, y = prepare_data(X, y, settings)
results = train_clf(X, y, version, output_filename, metrics, num_iter)
......
......@@ -11,10 +11,11 @@ import shap
from typing import List
from utility.settings import load_settings
NUM_ITERATIONS = 2
NUM_ITERATIONS = 1
def main():
cases = ["Complete", "Compliance", "Fall", "Risk"]
#cases = ["Complete", "Compliance", "Alarm", "Fall"]
cases = ['Alarm']
for case in cases:
if case == "Complete":
settings = load_settings("complete_emb.yaml")
......@@ -24,13 +25,13 @@ def main():
settings = load_settings("compliance_emb.yaml")
dl = data_loader.ComplianceDataLoader("compliance_emb.csv", settings).load_data()
X, y = dl.get_data()
elif case == "Fall":
settings = load_settings("fall_emb.yaml")
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
elif case == "Alarm":
settings = load_settings("alarm_emb.yaml")
dl = data_loader.AlarmDataLoader("alarm_count.csv", settings).load_data()
X, y = dl.get_data()
else:
settings = load_settings("risk_emb.yaml")
dl = data_loader.RiskDataLoader("risk_emb.csv", settings).load_data()
settings = load_settings("fall_emb.yaml")
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
X, y = dl.get_data()
features = dl.get_features()
......@@ -64,13 +65,12 @@ def get_best_shap_features(X: np.ndarray, y: np.ndarray,
neg, pos = np.bincount(y)
scale_pos_weight = neg / pos
model = xgb.XGBClassifier(n_estimators=200,
learning_rate=0.07,
model = xgb.XGBClassifier(n_estimators=400,
learning_rate=0.1,
objective='binary:logistic',
scale_pos_weight=scale_pos_weight,
eval_metric='logloss',
use_label_encoder=False,
n_jobs=-1,
random_state=seed)
acc_score_list = list()
......
......@@ -23,7 +23,7 @@ def main():
# Load the data
file_name = "fall_emb.csv"
dl = data_loader.FallDataLoader(file_name, settings).load_data()
dl = data_loader.AlarmDataLoader(file_name, settings).load_data()
X, y = dl.get_data()
X = X.drop(['Gender_Female'], axis=1)
......
......@@ -93,11 +93,11 @@ def main():
X, y = dl.get_data()
elif case == "Fall":
settings = load_settings("fall_emb.yaml")
dl = data_loader.FallDataLoader("fall_emb.csv", settings).load_data()
dl = data_loader.AlarmDataLoader("fall_emb.csv", settings).load_data()
X, y = dl.get_data()
else:
settings = load_settings("risk_emb.yaml")
dl = data_loader.RiskDataLoader("risk_emb.csv", settings).load_data()
dl = data_loader.FallDataLoader("risk_emb.csv", settings).load_data()
X, y = dl.get_data()
emb_cols = X.filter(regex='((\d+)[Ats|Ex])\w+', axis=1)
......
......@@ -11,13 +11,14 @@ import yaml
def main():
make_complete_count()
make_compliance_count()
make_alarm_count()
make_fall_count()
make_risk_count()
def make_complete_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "complete_emb.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
case = 'Complete'
label_name = 'Complete'
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'complete.csv',
......@@ -31,14 +32,15 @@ def make_complete_count():
df_ats = df_ats.drop(['Ats_0'], axis=1)
df = df.drop(cols_ats, axis=1)
df = pd.concat([df.drop(case, axis=1), df_ats, df[[case]]], axis=1)
df = pd.concat([df.drop(label_name, axis=1), df_ats, df[[label_name]]], axis=1)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, 'complete_count.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_count.csv')
def make_compliance_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "compliance_emb.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
case = 'Compliance'
label_name = 'Compliance'
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'compliance.csv',
......@@ -51,18 +53,17 @@ def make_compliance_count():
df_ats = df_ats.drop(['Ats_0'], axis=1)
df = df.drop(cols_ats, axis=1)
df = pd.concat([df.drop(case, axis=1), df_ats, df[[case]]], axis=1)
df = pd.concat([df.drop(label_name, axis=1), df_ats, df[[label_name]]], axis=1)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, 'compliance_count.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_count.csv')
def make_fall_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "fall_emb.yaml"), 'r') as stream:
def make_alarm_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "alarm_emb.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
case = 'Fall'
label_name = 'Alarm'
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'fall.csv',
converters=ats)
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR, f'alarm.csv', converters=ats)
cols_ats = [str(i)+'Ats' for i in range(1, settings['ats_resolution']+1)]
unique_ats = [df[f'{i}Ats'].unique() for i in range(1, settings['ats_resolution']+1)]
......@@ -72,22 +73,21 @@ def make_fall_count():
df_ats = df_ats.drop(['Ats_0'], axis=1)
df = df.drop(cols_ats, axis=1)
df = pd.concat([df.drop(case, axis=1), df_ats, df[[case]]], axis=1)
df = pd.concat([df.drop(label_name, axis=1), df_ats, df[[label_name]]], axis=1)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, 'fall_count.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_count.csv')
def make_risk_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "risk_emb.yaml"), 'r') as stream:
def make_fall_count():
with open(Path.joinpath(pt.CONFIGS_DIR, "fall_emb.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
case = 'Risk'
label_name = 'Fall'
ex = {str(i)+'Ex':str for i in range(1, settings['ex_resolution']+1)}
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
converters = {**ex, **ats}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'risk.csv',
converters=converters)
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR, f'fall.csv', converters=converters)
num_cols = embedder.get_numerical_cols(df, case)
num_cols = embedder.get_numerical_cols(df, label_name)
# Extract exercises
cols_ex = [str(i)+'Ex' for i in range(1, settings['ex_resolution']+1)]
......@@ -105,9 +105,9 @@ def make_risk_count():
df = pd.concat([df, df_ex, df_ats], axis=1)
ex_columns = ['Ex_' + ex for ex in unique_ex]
ats_columns = ['Ats_' + ats for ats in unique_ats]
df = df[num_cols + ex_columns + ats_columns + [case]]
df = df[num_cols + ex_columns + ats_columns + [label_name]]
df = df.drop(['Ex_0', 'Ats_0'], axis=1)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, 'risk_count.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_count.csv')
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -15,30 +15,30 @@ USE_GROUPING = False
ENABLE_EMB_VIZ = False
def main(ats_resolution: int = None):
for target_name in ["Complete", "Compliance", "Fall", "Risk"]:
# Load settings for target
for label_name in ["Complete", "Compliance", "Alarm", "Fall"]:
with open(Path.joinpath(pt.CONFIGS_DIR,
f'{target_name.lower()}_emb.yaml'), 'r') as stream:
f'{label_name.lower()}_emb.yaml'), 'r') as stream:
settings = yaml.safe_load(stream)
if ats_resolution == None:
ats_resolution = settings['ats_resolution']
if target_name == "Risk":
if label_name == "Fall":
ex_resolution = settings['ex_resolution']
if target_name in ["Complete", "Compliance", "Fall"]:
if label_name in ["Complete", "Compliance", "Alarm"]:
ats = {str(i)+'Ats':str for i in range(1, ats_resolution+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'{target_name.lower()}.csv',
f'{label_name.lower()}.csv',
converters=ats)
else:
ex = {str(i)+'Ex':str for i in range(1, ex_resolution+1)}
ats = {str(i)+'Ats':str for i in range(1, ats_resolution+1)}
converters = {**ex, **ats}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'{target_name.lower()}.csv',
f'{label_name.lower()}.csv',
converters=converters)
if target_name in ["Complete", "Compliance", "Fall"]:
if label_name in ["Complete", "Compliance", "Alarm"]:
emb_cols = df.filter(regex='((\d+)[Ats])\w+', axis=1)
n_numerical_cols = df.shape[1] - emb_cols.shape[1] - 1
df_to_enc = df.iloc[:,n_numerical_cols:]
......@@ -47,18 +47,18 @@ def main(ats_resolution: int = None):
else:
ats_cols = [str(i)+'Ats' for i in range(1, ats_resolution+1)]
ex_cols = [str(i)+'Ex' for i in range(1, ex_resolution+1)]
df_ats_to_enc = df.filter(regex=f'Risk|((\d+)[Ats])\w+', axis=1)
df_ex_to_enc = df.filter(regex=f'Risk|((\d+)[Ex])\w+', axis=1)
df_ats_to_enc = df.filter(regex=f'Fall|((\d+)[Ats])\w+', axis=1)
df_ex_to_enc = df.filter(regex=f'Fall|((\d+)[Ex])\w+', axis=1)
df = df.drop(ats_cols + ex_cols, axis=1)
# Load embedded config
with open(Path.joinpath(pt.CONFIGS_DIR,
f"{target_name.lower()}_emb.yaml"), 'r') as stream:
f"{label_name.lower()}_emb.yaml"), 'r') as stream:
emb_cfg = yaml.safe_load(stream)
# Encode dataframe given params
model_path = Path.joinpath(pt.ROOT_DIR, emb_cfg['model_path'])
if target_name in ["Complete", "Compliance", "Fall"]:
if label_name in ["Complete", "Compliance", "Alarm"]:
df_enc = encode_dataframe(df=df_to_enc,
target_name=emb_cfg['target_name'],
batch_size=emb_cfg['batch_size'],
......@@ -90,13 +90,13 @@ def main(ats_resolution: int = None):
df_rand = pd.DataFrame(np.random.rand(len(df),1), columns=['Rand']) # add random var
if target_name in ["Complete", "Compliance", "Fall"]:
df = pd.concat([df.drop(target_name, axis=1), df_rand, df_enc, df.pop(target_name)], axis=1)
if label_name in ["Complete", "Compliance", "Alarm"]:
df = pd.concat([df.drop(label_name, axis=1), df_rand, df_enc, df.pop(label_name)], axis=1)
else:
df = pd.concat([df.drop(target_name, axis=1), df_rand, ats_enc, ex_enc,
df.pop(target_name)], axis=1)
df = pd.concat([df.drop(label_name, axis=1), df_rand, ats_enc, ex_enc,
df.pop(label_name)], axis=1)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{target_name.lower()}_emb.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_emb.csv')
def encode_dataframe(df, target_name, batch_size, train_ratio, epochs,
optimizer, network_layers, verbose, model_path):
......
#!/usr/bin/env python
import paths as pt
from tools import file_reader, file_writer, target_maker
from tools import file_reader, file_writer, labeler
from tools import preprocessor
import pandas as pd
import yaml
......@@ -8,42 +8,39 @@ from pathlib import Path
def main(ats_resolution: int = None):
clusters = file_reader.read_csv(pt.INTERIM_DATA_DIR, 'cl.csv',
converters={'CitizenId': str, 'Cluster': int})
converters={'CitizenId': str, 'Cluster': int})
screenings = file_reader.read_csv(pt.INTERIM_DATA_DIR, 'screenings.csv',
converters={'CitizenId': str})
fall_data = pd.DataFrame(file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'fd.pkl'))
fall_data = fall_data.drop_duplicates(["CitizenId", "Date"])
converters={'CitizenId': str})
for target_name in ['Complete', 'Compliance', 'Fall', 'Risk']:
for label_name in ['Complete', 'Compliance', 'Alarm', 'Fall']:
# Load settings for target
with open(Path.joinpath(pt.CONFIGS_DIR,
f'{target_name.lower()}_emb.yaml'), 'r') as stream:
with open(Path.joinpath(pt.CONFIGS_DIR, f'{label_name.lower()}_emb.yaml'), 'r') as stream:
settings = yaml.safe_load(stream)
if ats_resolution == None:
ats_resolution = settings['ats_resolution']
if target_name == "Risk":
if label_name == "Fall":
ex_resolution = settings['ex_resolution']
features = settings['features']
df = screenings.copy()
df['Cluster'] = clusters['Cluster']
# Split cat columns by ATS resolution
df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats',
resolution=ats_resolution)
if target_name == "Risk":
df = preprocessor.split_cat_columns(df, col_to_split='Ex', tag='Ex',
resolution=ex_resolution)
# Encode target label
if target_name == 'Complete':
df = target_maker.make_complete_target(df, settings)
elif target_name == 'Compliance':
df = target_maker.make_compliance_target(df, settings)
elif target_name == 'Fall':
df = target_maker.make_fall_target(df, settings)
if label_name == 'Complete':
df = labeler.make_complete_label(df, settings)
elif label_name == 'Compliance':
df = labeler.make_compliance_label(df, settings)
elif label_name == 'Alarm':
df = labeler.make_alarm_label(df, settings)
df['Ats'] = df['Ats'].apply(lambda x: x.replace('222718', '0'))
else:
df = target_maker.make_risk_target(df, fall_data, settings)
df = labeler.make_fall_label(df, settings)
# Split cat columns by ATS resolution
df = preprocessor.split_cat_columns(df, col_to_split='Ats', tag='Ats', resolution=ats_resolution)
if label_name == "Fall":
df = preprocessor.split_cat_columns(df, col_to_split='Ex', tag='Ex', resolution=ex_resolution)
# One-hot-encode gender variable
object_cols = ['Gender']
......@@ -53,15 +50,15 @@ def main(ats_resolution: int = None):
df['Gender_Male'] = df['Gender_Male'].astype(int)
# Concat dataframe in proper order
if target_name in ["Complete", "Compliance", "Fall"]:
if label_name in ["Complete", "Compliance", "Alarm"]:
ats_cols = df.filter(regex='Ats', axis=1)
df = pd.concat([df[features], ats_cols, df[[target_name]]], axis=1)
df = pd.concat([df[features], ats_cols, df[[label_name]]], axis=1)
else:
ats_ex_cols = df.filter(regex='Ats|Ex', axis=1)
df = pd.concat([df[features], ats_ex_cols, df[[target_name]]], axis=1)
df = pd.concat([df[features], ats_ex_cols, df[[label_name]]], axis=1)
if settings['use_real_ats_names']:
if target_name in ["Complete", "Compliance", "Fall"]:
if label_name in ["Complete", "Compliance", "Alarm"]:
ats = file_reader.read_csv(pt.REFERENCES_DIR, 'ats.csv',
converters={'ats_id': str})
df = preprocessor.replace_cat_values(df, ats)
......@@ -73,7 +70,7 @@ def main(ats_resolution: int = None):
df = preprocessor.replace_cat_values(df, ats)
df = preprocessor.replace_cat_values(df, ex)
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{target_name.lower()}.csv')
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}.csv')
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -9,27 +9,23 @@ from pathlib import Path