Commit bbac67c9 authored by Christian Marius Lillelund

Type improvements

parent 6800d34f
Pipeline #39917 passed in 2 minutes and 14 seconds
......@@ -10,8 +10,8 @@ def main():
converters = {'CitizenId': str}
df = file_reader.read_csv(cfg.INTERIM_DATA_DIR, 'timeseries.csv', converters=converters)
df = feature_maker.make_number_completed(df)
df = preprocessor.split_categorial_columns(df, col='Exercises', tag='Ex', resolution=10)
df = preprocessor.split_categorial_columns(df, col='Ats', tag='Ats', resolution=10)
df = preprocessor.split_categorical_columns(df, col='Exercises', tag='Ex', resolution=10)
df = preprocessor.split_categorical_columns(df, col='Ats', tag='Ats', resolution=10)
# Create complete feature
df['Complete'] = df.groupby(['CitizenId'])['NumberCompleted'] \
......
......@@ -57,7 +57,7 @@ def main():
df['Fall'] = pd.Series.astype(df['Fall'], dtype=int)
# Make file with ats + patient data
df = preprocessor.split_categorial_columns(df, col='Ats', tag='Ats', resolution=10)
df = preprocessor.split_categorical_columns(df, col='Ats', tag='Ats', resolution=10)
df = pd.concat([df, df.pop("Fall")], axis=1) # rearrange
file_writer.write_csv(df, cfg.PROCESSED_DATA_DIR, 'fall.csv')
......
......@@ -53,7 +53,6 @@ def make_window_features(id, data):
end_date = feature_maker.convert_date_to_datetime(end_date, date_format)
window_features['StartDate'] = start_date
window_features['EndDate'] = end_date
window_features['LastStatusDate'] = feature_maker.get_last_status_date(ssw, end_date, date_format)
n_weeks = feature_maker.get_interval_length(start_date, end_date)
window_features['NumberWeeks'] = n_weeks
......
import numpy as np
from typing import List, Tuple
import pandas as pd
import sklearn
from sklearn.model_selection import StratifiedKFold
from statistics import mean
import os
......@@ -7,14 +9,25 @@ from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from utility import metrics
def fit_random_forest(X, y):
def fit_random_forest(X: List[np.ndarray],
y: np.ndarray) -> Tuple[float,
float,
float,
np.ndarray,
RandomForestClassifier]:
clf = RandomForestClassifier(n_estimators=400,
class_weight='balanced',
random_state=0)
mean_auc, std_auc, mean_acc, cm, model = make_cross_val(clf, X, y)
return mean_auc, std_auc, mean_acc, cm, model
def make_cross_val(clf, X, y, n_splits=5, shuffle=True, random_state=0):
def make_cross_val(clf: RandomForestClassifier, X: List[np.ndarray],
y: np.ndarray, n_splits: int=5, shuffle: bool=True,
random_state: int=0) -> Tuple[float,
float,
float,
np.ndarray,
RandomForestClassifier]:
model_results_roc = list()
model_results_acc = list()
model_results_auc = list()
......@@ -59,14 +72,21 @@ def make_cross_val(clf, X, y, n_splits=5, shuffle=True, random_state=0):
mean_model_acc = mean(model_results_acc)
total_confusion_matrix = total_confusion_matrix / n_splits
return mean_model_auc, std_model_auc, mean_model_acc, total_confusion_matrix, models[len(models)-1]
return mean_model_auc, std_model_auc, mean_model_acc, \
total_confusion_matrix, models[len(models)-1]
# From https://stackoverflow.com/questions/56781373/how-to-calculate-auc-for-random-forest-model-in-sklearn
def train_and_predict(clf, X_train, X_test, y_train):
def train_and_predict(clf: RandomForestClassifier,
X_train: List[np.ndarray],
X_test: List[np.ndarray],
y_train: np.ndarray) -> Tuple[RandomForestClassifier,
np.ndarray,
np.ndarray]:
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
y_pred_proba = clf.predict_proba(X_test)[:, 1]
return clf, y_pred, y_pred_proba
def get_data_by_index(X, y, index):
def get_data_by_index(X: List[np.ndarray],
y: np.ndarray,
index: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
return (np.array(X)[index.astype(int)], np.array(y)[index.astype(int)])
\ No newline at end of file
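For reference, a minimal sketch of how the newly annotated fit_random_forest could be called on synthetic data; the import path and the data shapes are assumptions for illustration, not taken from the diff:

import numpy as np
from tools import classifier  # assumed module path for fit_random_forest

rng = np.random.default_rng(0)
X = list(rng.normal(size=(100, 12)))   # List[np.ndarray]: one feature vector per sample
y = rng.integers(0, 2, size=100)       # np.ndarray of binary labels
mean_auc, std_auc, mean_acc, cm, model = classifier.fit_random_forest(X, y)
print(f"AUC {mean_auc:.3f} ± {std_auc:.3f}, accuracy {mean_acc:.3f}")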
......@@ -5,6 +5,7 @@ import pandas as pd
import numpy as np
import os
import abc
from typing import List
class BaseCleaner(metaclass=abc.ABCMeta):
@abc.abstractmethod
......@@ -40,37 +41,38 @@ class BaseCleaner(metaclass=abc.ABCMeta):
"""Cleans the fall set"""
class Cleaner2020(BaseCleaner):
def clean_clusters(self, cl):
def clean_clusters(self, cl: pd.DataFrame) -> pd.DataFrame:
return cl
def clean_patient_data(self, ptd):
def clean_patient_data(self, ptd: pd.DataFrame) -> pd.DataFrame:
ptd = remove_citizens_without_valid_id(ptd)
return ptd
def clean_screening_content(self, sc, ptd):
def clean_screening_content(self, sc: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
sc = remove_citizens_not_in_patient_data(sc, ptd, cfg.CITIZEN_ID)
sc = merge_train_and_patient_data(sc, ptd, cfg.CITIZEN_ID)
sc = sort_dataframe(sc, [cfg.CITIZEN_ID, cfg.SCREENING_DATE])
return sc
def clean_status_set(self, ss, ptd):
def clean_status_set(self, ss: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
ss = remove_citizens_not_in_patient_data(ss, ptd, cfg.CITIZEN_ID)
ss = merge_train_and_patient_data(ss, ptd, cfg.CITIZEN_ID)
ss = sort_dataframe(ss, [cfg.CITIZEN_ID, cfg.CHANGE_DATE])
return ss
def clean_training_done(self, td, ptd):
def clean_training_done(self, td: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
td = remove_citizens_not_in_patient_data(td, ptd, cfg.CITIZEN_ID)
td = sort_dataframe(td, [cfg.CITIZEN_ID, cfg.RATING_DATE])
return td
def clean_training_cancelled(self, tc, ptd):
def clean_training_cancelled(self, tc: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
tc = remove_citizens_not_in_patient_data(tc, ptd, cfg.CITIZEN_ID)
tc = merge_train_and_patient_data(tc, ptd, cfg.CITIZEN_ID)
tc = sort_dataframe(tc, [cfg.CITIZEN_ID, cfg.RATING_DATE])
return tc
def clean_assistive_aids(self, ats, ic, ids=None):
def clean_assistive_aids(self, ats: pd.DataFrame, ic: pd.DataFrame,
ids: List[str]=None) -> pd.DataFrame:
ats = sort_dataframe(ats, [cfg.CITIZEN_ID, cfg.LEND_DATE])
ats = remove_citizens_without_valid_id(ats)
ats = remove_rows_with_old_dates(ats, cfg.LEND_DATE)
......@@ -79,45 +81,46 @@ class Cleaner2020(BaseCleaner):
ats = drop_invalid_devices(ats, ic)
return ats
def clean_fall_data(self, fd):
def clean_fall_data(self, fd: pd.DataFrame) -> pd.DataFrame:
fd = remove_citizens_without_valid_id(fd)
fd = sort_dataframe(fd, [cfg.CITIZEN_ID, cfg.DATE])
return fd
class Cleaner2019(BaseCleaner):
def clean_clusters(self, cl):
def clean_clusters(self, cl: pd.DataFrame) -> pd.DataFrame:
return cl
def clean_patient_data(self, ptd):
def clean_patient_data(self, ptd: pd.DataFrame) -> pd.DataFrame:
ptd = remove_citizens_without_valid_id(ptd)
return ptd
def clean_screening_content(self, sc, ptd):
def clean_screening_content(self, sc: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
sc = remove_citizens_not_in_patient_data(sc, ptd, cfg.PATIENT_ID)
sc = remove_screenings_without_exercises(sc)
sc = merge_train_and_patient_data(sc, ptd, cfg.PATIENT_ID)
sc = sort_dataframe(sc, [cfg.CITIZEN_ID, cfg.SCREENING_DATE])
return sc
def clean_status_set(self, ss, ptd):
def clean_status_set(self, ss: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
ss = remove_citizens_not_in_patient_data(ss, ptd, cfg.PATIENT_ID)
ss = merge_train_and_patient_data(ss, ptd, cfg.PATIENT_ID)
ss = sort_dataframe(ss, [cfg.CITIZEN_ID, cfg.CHANGE_DATE])
return ss
def clean_training_done(self, td, ptd):
def clean_training_done(self, td: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
td = remove_citizens_not_in_patient_data(td, ptd, cfg.PATIENT_ID)
td = merge_train_and_patient_data(td, ptd, cfg.PATIENT_ID)
td = sort_dataframe(td, [cfg.CITIZEN_ID, cfg.RATING_DATE])
return td
def clean_training_cancelled(self, tc, ptd):
def clean_training_cancelled(self, tc: pd.DataFrame, ptd: pd.DataFrame) -> pd.DataFrame:
tc = remove_citizens_not_in_patient_data(tc, ptd, cfg.PATIENT_ID)
tc = merge_train_and_patient_data(tc, ptd, cfg.PATIENT_ID)
tc = sort_dataframe(tc, [cfg.CITIZEN_ID, cfg.RATING_DATE])
return tc
def clean_assistive_aids(self, ats, ic, ids=None):
def clean_assistive_aids(self, ats: pd.DataFrame, ic: pd.DataFrame,
ids: List[str]=None) -> pd.DataFrame:
ats = sort_dataframe(ats, [cfg.CITIZEN_ID, cfg.LEND_DATE])
ats = filter_ats_on_ids(ats, ids)
ats = remove_rows_with_old_dates(ats, cfg.LEND_DATE)
......@@ -127,21 +130,23 @@ class Cleaner2019(BaseCleaner):
ats = drop_invalid_devices(ats, ic)
return ats
def clean_fall_data(self, fd):
def clean_fall_data(self, fd: pd.DataFrame) -> pd.DataFrame:
raise NotImplementedError
def drop_invalid_devices(ats, iso_classes):
def drop_invalid_devices(ats: pd.DataFrame, iso_classes: pd.DataFrame) -> pd.DataFrame:
return ats[ats[cfg.DEV_ISO_CLASS].isin(iso_classes.DevISOClass)]
def remove_screenings_without_exercises(df):
def remove_screenings_without_exercises(df: pd.DataFrame) -> pd.DataFrame:
df = df[df[cfg.EXERCISE_CONTENT] != 'nan']
return df
def remove_citizens_not_in_patient_data(train_data, patient_data, id):
def remove_citizens_not_in_patient_data(train_data: pd.DataFrame,
patient_data: pd.DataFrame,
id: str) -> pd.DataFrame:
data = train_data[train_data[id].isin(patient_data[id].unique())]
return data
def remove_citizens_without_valid_id(df):
def remove_citizens_without_valid_id(df: pd.DataFrame) -> pd.DataFrame:
df = df[df[cfg.CITIZEN_ID] != "0000000000"]
df = df[df[cfg.CITIZEN_ID] != '0']
df = df[df[cfg.CITIZEN_ID] != "#VALUE!"]
......@@ -149,24 +154,26 @@ def remove_citizens_without_valid_id(df):
df = df[df[cfg.CITIZEN_ID] != '681']
return df
def merge_train_and_patient_data(train_data, patient_data, key):
def merge_train_and_patient_data(train_data: pd.DataFrame,
patient_data: pd.DataFrame,
key: str) -> pd.DataFrame:
return pd.merge(train_data, patient_data, on=key)
def sort_dataframe(data, by):
def sort_dataframe(data: pd.DataFrame, by: str) -> pd.DataFrame:
return data.sort_values(by)
def filter_ats_on_ids(ats, ids):
def filter_ats_on_ids(ats: pd.DataFrame, ids: List[str]) -> pd.DataFrame:
return ats[ats[cfg.CITIZEN_ID].isin(ids)]
def remove_tainted_histories(ats):
def remove_tainted_histories(ats: pd.DataFrame) -> pd.DataFrame:
tainted_ids = ats[ats[cfg.DEV_HMI_NUMBER] == '899,999'][cfg.CITIZEN_ID].unique()
ats = ats[np.logical_not(ats[cfg.CITIZEN_ID].isin(tainted_ids))]
return ats
def remove_deprecated_device_data(ats):
def remove_deprecated_device_data(ats: pd.DataFrame) -> pd.DataFrame:
return ats[ats[cfg.DEV_HMI_NUMBER] != '899,999']
def remove_rows_with_old_dates(ats, col):
def remove_rows_with_old_dates(ats: pd.DataFrame, col: str) -> pd.DataFrame:
ats[col] = pd.to_datetime(ats[col])
mask = (ats[col] >= '1900-01-01') & (ats[col] <= pd.Timestamp('today'))
return ats.loc[mask]
\ No newline at end of file
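As a toy illustration of the filtering idiom the cleaners rely on (invented data; the column names stand in for cfg.CITIZEN_ID and cfg.LEND_DATE):

import pandas as pd

ats = pd.DataFrame({'CitizenId': ['0000000000', '12345', '67890'],
                    'LendDate': ['2018-05-01', '2019-03-12', '1899-11-07']})
# drop placeholder ids, as remove_citizens_without_valid_id does
ats = ats[ats['CitizenId'] != '0000000000']
# keep only plausible lend dates, as remove_rows_with_old_dates does
ats['LendDate'] = pd.to_datetime(ats['LendDate'])
mask = (ats['LendDate'] >= '1900-01-01') & (ats['LendDate'] <= pd.Timestamp('today'))
ats = ats.loc[mask]   # only citizen 12345 remains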
from pathlib import Path
import pandas as pd
from ast import literal_eval
import pickle
def read_pickle(path, filename):
return pd.read_pickle(Path.joinpath(path, filename))
from typing import List
def read_csv_path(path, header='infer', sep=',', names=None, converters=None):
return pd.read_csv(path, header=header, sep=sep, names=names,
converters=converters)
def read_pickle(path: Path, filename: str):
return pd.read_pickle(Path.joinpath(path, filename))
def read_csv(path, filename, header='infer', sep=',', names=None, converters=None):
def read_csv(path: Path, filename: str, header: str='infer',
sep: str=',', names: List[str]=None, converters: dict=None):
return pd.read_csv(Path.joinpath(path, filename), header=header,
sep=sep, names=names, converters=converters)
def read_excelfile(path, filename, converters=None):
def read_excelfile(path: Path, filename: str, converters: dict=None):
file = pd.ExcelFile(Path.joinpath(path, filename))
return file.parse(file.sheet_names[0], converters=converters)
def read_excelfile_sheets(path, filename, n_sheets, converters=None):
def read_excelfile_sheets(path: Path, filename: str, n_sheets: int, converters: dict=None):
file = pd.ExcelFile(Path.joinpath(path, filename))
full_file = pd.DataFrame()
for i in range(n_sheets):
df = file.parse(file.sheet_names[i], converters=converters)
full_file = pd.concat([full_file, df])
return full_file
\ No newline at end of file
return full_file
\ No newline at end of file
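A minimal usage sketch for the annotated readers; the import path and data directory are assumptions, while the converters dict mirrors the scripts above that pass CitizenId as str:

from pathlib import Path
from tools import file_reader  # assumed import path

converters = {'CitizenId': str}   # keep ids as strings, avoiding float conversion
df = file_reader.read_csv(Path('data/interim'), 'timeseries.csv', converters=converters)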
......@@ -5,27 +5,12 @@ import dill
from typing import List
from network.network_config import *
def write_csv(df, path, file_name, date_format='%d-%m-%Y', index=False):
df.to_csv(Path.joinpath(path, file_name), date_format=date_format, index=index)
def write_csv(df: pd.DataFrame, path: Path, file_name: str,
date_format: str='%d-%m-%Y',
index: bool=False) -> None:
df.to_csv(Path.joinpath(path, file_name),
date_format=date_format,
index=index)
def write_pickle(df, path, file_name):
df.to_pickle(Path.joinpath(path, file_name))
def write_model(model, path, file_name):
with open(str(Path.joinpath(path, file_name)), 'wb') as output:
pickle.dump(model, output)
def write_explainer(exp, path, file_name):
with open(str(Path.joinpath(path, file_name)), 'wb') as output:
dill.dump(exp, output)
def write_explanation(exp, path, file_name):
exp.save_to_file(str(Path.joinpath(path, file_name)))
def save_weights(weights: List, config: AirNet) -> None:
with open(config.get_weights_path(), 'wb') as f:
pickle.dump(weights, f, -1)
def save_labels(labels: List, config: AirNet) -> None:
with open(config.get_labels_path(), 'wb') as f:
pickle.dump(labels, f, -1)
\ No newline at end of file
def write_pickle(df: pd.DataFrame, path: Path, file_name: str) -> None:
df.to_pickle(Path.joinpath(path, file_name))
\ No newline at end of file
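And the matching writer call, again with the import path and output directory assumed:

from pathlib import Path
import pandas as pd
from tools import file_writer  # assumed import path

df = pd.DataFrame({'CitizenId': ['12345'], 'Fall': [1]})
file_writer.write_csv(df, Path('data/processed'), 'fall.csv')  # dates formatted as %d-%m-%Y by default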
......@@ -46,10 +46,10 @@ class BaseParser(metaclass=abc.ABCMeta):
"""Parse the DiGiRehab fall data set"""
class Parser2020(BaseParser):
def assert_datetime(self, entry):
def assert_datetime(self, entry: str):
return pd.to_datetime(entry, format='%d-%m-%Y', errors='coerce') is not pd.NaT
def parse_status_set(self, file_name, path):
def parse_status_set(self, file_name: str, path: Path):
index_col = 'BorgerID'
converters = {index_col: str} # pass as string to avoid float conversion
ss_df = file_reader.read_excelfile(path, file_name, converters=converters)
......@@ -76,7 +76,7 @@ class Parser2020(BaseParser):
return df
def parse_training_cancelled(self, file_name, path):
def parse_training_cancelled(self, file_name: str, path: Path):
index_col = 'BorgerID'
converters = {index_col: str} # pass as string to avoid float conversion
tc_df = file_reader.read_excelfile(path, file_name, converters=converters)
......@@ -99,7 +99,7 @@ class Parser2020(BaseParser):
return df
def parse_screening_content(self, file_name, path):
def parse_screening_content(self, file_name: str, path: Path):
index_col = 'BorgerID'
converters = {index_col: str} # pass as string to avoid float conversion
sc_df = file_reader.read_excelfile(path, file_name, converters=converters)
......@@ -133,7 +133,7 @@ class Parser2020(BaseParser):
return df
def parse_training_done(self, file_name, path):
def parse_training_done(self, file_name: str, path: Path):
index_col = 'BorgerID'
converters = {index_col: str} # pass as string to avoid float conversion
td_df = file_reader.read_excelfile(path, file_name, converters=converters)
......@@ -161,7 +161,7 @@ class Parser2020(BaseParser):
df[cfg.BIRTH_YEAR] = pd.Series.astype(df[cfg.BIRTH_YEAR], dtype=int)
return df
def parse_assistive_aids(self, file_name, path, n_sheets=6):
def parse_assistive_aids(self, file_name: str, path: Path, n_sheets: int=6):
converters = {'BorgerID': str,
'HJAELPEMIDDELHMINUMMER': str,
'HJAELPEMIDDELISOKLASSE': str} # pass as str to avoid int conversion
......@@ -195,12 +195,12 @@ class Parser2020(BaseParser):
df[cfg.PRICE] = pd.Series.astype(df[cfg.PRICE], dtype=str)
return df
def parse_clusters(self, file_name, path):
def parse_clusters(self, file_name: str, path: Path):
cluster_file = Path.joinpath(path, file_name)
df = pd.read_csv(cluster_file)
return df
def parse_iso_classes(self, file_name, path):
def parse_iso_classes(self, file_name: str, path: Path):
isoclass_file = Path.joinpath(path, file_name)
df = pd.read_csv(isoclass_file,
header=None,
......@@ -209,7 +209,7 @@ class Parser2020(BaseParser):
converters={i: str for i in range(0, 10000)})
return df
def parse_fall_data(self, file_name, path):
def parse_fall_data(self, file_name: str, path: Path):
converters = {'BorgerId': str} # pass as string to avoid float conversion
df = file_reader.read_excelfile(path, file_name, converters)
df = df.rename(columns={'BorgerId': cfg.CITIZEN_ID,
......@@ -230,7 +230,7 @@ class Parser2020(BaseParser):
return df
def parse_patient_data(self, file_name, path):
def parse_patient_data(self, file_name: str, path: Path):
raise NotImplementedError
class Parser2019(BaseParser):
......
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from ast import literal_eval
from ast import Str, literal_eval
from sklearn.preprocessing import MinMaxScaler
from tools import feature_maker
from typing import List, Tuple
......@@ -19,7 +19,6 @@ def series_to_list(series: pd.Series) -> List:
list_cols = []
for item in series:
list_cols.append(item)
return list_cols
def sample(X: np.ndarray, y: np.ndarray, n: int) -> Tuple[np.ndarray, np.ndarray]:
......@@ -45,8 +44,8 @@ def get_X_y(df: pd.DataFrame, name_target: str) -> Tuple[List, List]:
return X_list, y_list
def label_encode(data: List, n_numerical_cols: int) -> [np.ndarray, List[LabelEncoder]]:
def label_encode(data: List[np.ndarray], n_numerical_cols: int) -> Tuple[List[np.ndarray],
List[LabelEncoder]]:
"""
This method is used to perform Label Encoding on a given list
:param data: the list containing the items to be encoded
......@@ -76,12 +75,12 @@ def transpose_to_list(X: np.ndarray) -> List[np.ndarray]:
return features_list
def get_ats_list(ats):
def get_ats_list(ats: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame(ats.groupby(['CitizenId'])['DevISOClass'].apply(lambda x: ",".join(x))).reset_index()
df = df.rename(columns={'DevISOClass': 'Ats'})
return df
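For illustration, get_ats_list applied to a toy frame (ISO class codes invented; import path assumed):

import pandas as pd
from tools import preprocessor  # assumed import path

ats = pd.DataFrame({'CitizenId': ['1', '1', '2'],
                    'DevISOClass': ['120606', '093307', '043306']})
print(preprocessor.get_ats_list(ats))
#   CitizenId            Ats
# 0         1  120606,093307
# 1         2         043306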
def get_class_weights(neg: int, pos: int):
def get_class_weights(neg: int, pos: int) -> dict:
total = neg + pos
weight_for_0 = (1 / neg)*(total)/2.0
weight_for_1 = (1 / pos)*(total)/2.0
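A quick worked example of the weight formula, assuming a 90/10 class split:

neg, pos = 90, 10                        # assumed class counts
total = neg + pos
weight_for_0 = (1 / neg) * total / 2.0   # ≈ 0.56
weight_for_1 = (1 / pos) * total / 2.0   # = 5.0, so each minority sample weighs ~9x more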
......@@ -92,7 +91,8 @@ def prepare_network_data(df: pd.DataFrame,
target_name: str,
n_numerical_cols: int,
train_ratio: float,
use_sparsity: bool) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List]:
use_sparsity: bool) -> Tuple[np.ndarray, np.ndarray,
np.ndarray, np.ndarray, List]:
X, y = get_X_y(df, target_name)
X, labels = label_encode(X, n_numerical_cols)
y = np.array(y)
......@@ -109,7 +109,7 @@ def prepare_network_data(df: pd.DataFrame,
return X_train, X_val, y_train, y_val, labels
def split_categorial_columns(df: pd.DataFrame, col: str, tag: str, resolution: int):
def split_categorical_columns(df: pd.DataFrame, col: str, tag: str, resolution: int):
split = pd.DataFrame(df[col].str.split(pat=",", expand=True))
split = split.drop(split.iloc[:, resolution:], axis=1)
split = split.fillna(0)
......@@ -122,7 +122,8 @@ def split_categorial_columns(df: pd.DataFrame, col: str, tag: str, resolution: i
pass
return df
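A sketch of the renamed split_categorical_columns on a toy frame; the names of the generated columns depend on code elided from this hunk, so only the call pattern and intent are shown (import path assumed):

import pandas as pd
from tools import preprocessor  # assumed import path

df = pd.DataFrame({'CitizenId': ['1', '2'],
                   'Ats': ['120606,093307', '043306']})
df = preprocessor.split_categorical_columns(df, col='Ats', tag='Ats', resolution=10)
# the comma-separated Ats string is expanded into at most 10 tag-prefixed columns,
# with missing positions filled with 0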
def create_improve_split(features, threshold):
def create_improve_split(features: pd.DataFrame, threshold: float) -> Tuple[np.ndarray, np.ndarray,
np.ndarray, np.ndarray]:
X = feature_maker.make_improve_feature_absolute(features, threshold)
y = X["Improve"]
......@@ -131,24 +132,24 @@ def create_improve_split(features, threshold):
return X_train, X_test, y_train, y_test
def drop_id_columns(df: pd.DataFrame):
def drop_id_columns(df: pd.DataFrame) -> pd.DataFrame:
df = df.drop(['CitizenId'], axis=1)
return df
def drop_date_columns(df: pd.DataFrame):
def drop_date_columns(df: pd.DataFrame) -> pd.DataFrame:
df = df.drop(['StartDate', 'EndDate'], axis=1)
return df
def drop_reason_cols(df):
def drop_reason_cols(df: pd.DataFrame) -> pd.DataFrame:
reason_cols = [col for col in df.columns if 'Reason' in col]
df = df.drop(reason_cols, axis=1)
return df
def drop_rows_with_no_needs_start_bl(df):
def drop_rows_with_no_needs_start_bl(df: pd.DataFrame) -> pd.DataFrame:
df = df[df['NeedsStartBaseline'] != 0]
return df
def convert_dates(X, cols):
def convert_dates(X: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
if cols_exist(X, cols):
for col in cols:
X[col] = pd.to_datetime(X[col], infer_datetime_format=True)
......@@ -157,7 +158,7 @@ def convert_dates(X, cols):
raise ValueError('DataFrame did not contain argument columns')
return X
def encode_vector_onehot(X, cols):
def encode_vector_onehot(X: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
if cols_exist(X, cols):
ohe = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_enc = pd.DataFrame(ohe.fit_transform(X[cols]))
......@@ -169,7 +170,7 @@ def encode_vector_onehot(X, cols):
else:
raise ValueError('DataFrame did not contain argument columns')
def encode_vector_dummy(X, cols):
def encode_vector_dummy(X: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
if cols_exist(X, cols):
for col in cols:
X_dummies = pd.get_dummies(X[col].apply(pd.Series).stack(dropna=False),
......@@ -181,7 +182,8 @@ def encode_vector_dummy(X, cols):
else:
raise ValueError('DataFrames did not contain argument columns')
def encode_vectors_dummy(X_train, X_test, cols):
def encode_vectors_dummy(X_train: pd.DataFrame, X_test: pd.DataFrame,
cols: List[str]) -> Tuple[pd.DataFrame, pd.DataFrame]:
if cols_exist(X_train, cols) and cols_exist(X_test, cols):
for col in cols:
X_train_dummies = pd.get_dummies(X_train[col].apply(pd.Series).stack(dropna=False),
......@@ -201,7 +203,8 @@ def encode_vectors_dummy(X_train, X_test, cols):
else:
raise ValueError('DataFrames did not contain argument columns')
def encode_vectors_onehot(X_train, X_test, cols):
def encode_vectors_onehot(X_train: pd.DataFrame, X_test: pd.DataFrame,
cols: List[str]) -> Tuple[pd.DataFrame, pd.DataFrame]:
if cols_exist(X_train, cols) and cols_exist(X_test, cols):
ohe = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_train_enc = pd.DataFrame(ohe.fit_transform(X_train[cols]))
......@@ -218,38 +221,38 @@ def encode_vectors_onehot(X_train, X_test, cols):
else:
raise ValueError('DataFrames did not contain argument columns')
def cols_exist(df, cols):
def cols_exist(df: pd.DataFrame, cols: List[str]) -> bool:
return set(cols).issubset(df.columns)
def filter_first_screening(X):
def filter_first_screening(X: pd.DataFrame) -> pd.DataFrame:
X = X.loc[X.ScreeningNo == 1]
X = X.reset_index(drop=True)
return X
def filter_ideal_candidates(X):
def filter_ideal_candidates(X: pd.DataFrame) -> pd.DataFrame:
X = X.loc[(X['NumberWeeksSum'] >= 8) & (X['NumberTrainingSum'] >= 7)]
X = X.drop_duplicates(subset='CitizenId')
X = X.reset_index(drop=True)