Commit 53c68a44 authored by thecml

improved survival script, added alarm prediction endpoint to the API

parent 931302d3
Pipeline #93165 passed with stage in 4 minutes and 19 seconds
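The commit adds an alarm-specific /predict_alarm endpoint to the API (see the API diff further down), returning event times and survival probabilities from a random survival forest. A minimal client sketch, assuming the API runs locally; it shows only the InputData fields visible in this diff, and the base URL, token and Ats ids are placeholders:

import requests

BASE_URL = "http://localhost:8000"   # placeholder; actual host/port depends on deployment
TOKEN = "<access_token>"             # from the JWT endpoints, if the route is protected

payload = {
    "Gender": 1,             # 0 or 1
    "BirthYear": 1940,       # validated against range(1900, 1999)
    "LoanPeriod": 360.0,     # must be >= 0
    "Ats": "120606, 093307"  # comma-separated ats_id values; placeholders here,
                             # each must exist in ats.csv and must not be an alarm device
}

resp = requests.post(f"{BASE_URL}/predict_alarm",
                     json=payload,
                     headers={"Authorization": f"Bearer {TOKEN}"})
resp.raise_for_status()
result = resp.json()

# EventTimes[i] pairs with SurvivalProbs[i]: the estimated probability that the
# citizen has not yet needed an alarm at that event time.
for t, p in zip(result["EventTimes"], result["SurvivalProbs"]):
    print(t, p)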
......@@ -2,7 +2,7 @@
# Settings for data scripts -------------------------------------------------
#
ats_delimiter: 6
ats_iso_length: 6
ats_resolution: 10
threshold_weeks: 8
threshold_training: 10
......
......@@ -3,31 +3,31 @@
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\r\n",
"import config as cfg\r\n",
"from tools import file_reader\r\n",
"from pathlib import Path\r\n",
"\r\n",
"df = file_reader.read_pickle(cfg.INTERIM_DATA_DIR, 'ats.pkl')\r\n",
"df['DevISOClass'] = df['DevISOClass'].apply(lambda x: x[:6])\r\n",
"df = df.drop_duplicates(subset=['DevISOClass'])\r\n",
"df = df[['DevHMIName', 'DevISOClass']]\r\n",
"df = df.sort_values('DevISOClass')\r\n",
"columns_titles = [\"DevISOClass\",\"DevHMIName\"]\r\n",
"df = df.reindex(columns=columns_titles)\r\n",
"df = df.rename(columns={\"DevISOClass\": \"ats_id\", \"DevHMIName\": \"ats_name\"})\r\n",
"\r\n",
"ats = file_reader.read_csv(cfg.REFERENCES_DIR, 'ats.csv',\r\n",
" converters={'ats_id': str})\r\n",
"\r\n",
"df = ats.merge(df, how='outer', on=['ats_id']).drop_duplicates(['ats_id'], keep='first')\r\n",
"\r\n",
"file_name = f\"ats full.csv\"\r\n",
"import pandas as pd\n",
"import config as cfg\n",
"from tools import file_reader\n",
"from pathlib import Path\n",
"\n",
"df = file_reader.read_pickle(cfg.INTERIM_DATA_DIR, 'ats.pkl')\n",
"df['DevISOClass'] = df['DevISOClass'].apply(lambda x: x[:6])\n",
"df = df.drop_duplicates(subset=['DevISOClass'])\n",
"df = df[['DevHMIName', 'DevISOClass']]\n",
"df = df.sort_values('DevISOClass')\n",
"columns_titles = [\"DevISOClass\",\"DevHMIName\"]\n",
"df = df.reindex(columns=columns_titles)\n",
"df = df.rename(columns={\"DevISOClass\": \"ats_id\", \"DevHMIName\": \"ats_name\"})\n",
"\n",
"ats = file_reader.read_csv(cfg.REFERENCES_DIR, 'ats.csv',\n",
" converters={'ats_id': str})\n",
"\n",
"df = ats.merge(df, how='outer', on=['ats_id']).drop_duplicates(['ats_id'], keep='first')\n",
"\n",
"file_name = f\"ats full.csv\"\n",
"df.to_csv(Path.joinpath(cfg.REFERENCES_DIR, file_name), index=False)"
],
"outputs": [],
"metadata": {}
]
}
],
"metadata": {
......@@ -51,4 +51,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
}
......@@ -723,7 +723,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
"version": "3.8.8"
},
"orig_nbformat": 4
},
......
......@@ -17,6 +17,7 @@ from fastapi import Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import yaml
import pickle
app = FastAPI(title='AIR API', version='0.1',
description='A simple API that classifies citizens based on data')
......@@ -86,7 +87,11 @@ class InputData(pydantic.BaseModel):
LoanPeriod: float
Ats: str
class OutputData(pydantic.BaseModel):
class AlarmOutputData(pydantic.BaseModel):
EventTimes: list
SurvivalProbs: list
class TrainingOutputData(pydantic.BaseModel):
CompleteProb: float
FallProb: float
Compliance: int
......@@ -160,28 +165,34 @@ def refresh(Authorize: AuthJWT = Depends()):
new_access_token = Authorize.create_access_token(subject=current_user)
return {"access_token": new_access_token}
@app.post('/predict', response_model=OutputData, tags=["ai"])
def predict(incoming_data: InputData):
@app.post('/predict_alarm', response_model=AlarmOutputData, tags=["ai"])
def predict_alarm(incoming_data: InputData):
settings = load_settings('config.yaml')
ats_resolution = settings['ats_resolution']
data = validate_data(incoming_data)
data = incoming_data.dict()
if not data['Gender'] in [0, 1]:
raise HTTPException(status_code=400, detail="Invalid gender, check input")
if not data['BirthYear'] in range(1900, 1999):
raise HTTPException(status_code=400, detail="Invalid birth year, check input")
if not data['LoanPeriod'] >= 0:
raise HTTPException(status_code=400, detail="Invalid loan period, check input")
incoming_ats = [x.strip(' ') for x in data['Ats'].split(",")]
if any(x in list(['222718']) for x in incoming_ats) == True:
raise HTTPException(status_code=400, detail="An alarm cannot be in feature set")
if data['Ats'] != '':
incoming_ats = [x.strip(' ') for x in data['Ats'].split(",")]
ats_converter = {'ats_id': str}
ats_df = read_dataframe('ats.csv', converters=ats_converter)
if all(x in list(ats_df['ats_id']) for x in incoming_ats) != True:
raise HTTPException(status_code=400, detail="Ats not found, check ats list")
else:
data['Ats'] = '0'
df = prepare_data(data, ats_resolution)
model = read_joblib("alarm_rsf.joblib")
label_encoders = read_pickle("alarm_labels.pkl")
df_for_alarm = add_label_encoding(df.copy(), label_encoders, ats_resolution)
surv_probs = model.predict_survival_function(df_for_alarm, return_array=True)
return {
'EventTimes': [int(x) for x in model.event_times_],
'SurvivalProbs': [float(x) for x in surv_probs[0]]
}
@app.post('/predict_training', response_model=TrainingOutputData, tags=["ai"])
def predict_training(incoming_data: InputData):
settings = load_settings('config.yaml')
ats_resolution = settings['ats_resolution']
data = validate_data(incoming_data)
df = prepare_data(data, ats_resolution)
complete_model = read_joblib(f'complete_xgboost.joblib')
......@@ -199,13 +210,6 @@ def predict(incoming_data: InputData):
else:
compliance_prob = 0
compliance = 0 if compliance_prob < 0.5 else 1
#cluster_id = int(df.iloc[0]['Cluster'])
#cluster_converter = {str(i):str for i in range(1, 20)}
#clusters = read_dataframe('clusters.csv', converters=cluster_converter)
#cluster_values = list(clusters.iloc[:, cluster_id])
#complete_shap_values = get_shap_values(complete_model, X_test=df_for_complete)
#fall_shap_values = get_shap_values(fall_model, X_test=df_for_fall)
complete_arguments = generate_arguments(df, ats_resolution, "Complete", float(complete_prob))
fall_arguments = generate_arguments(df, ats_resolution, "Fall", float(fall_prob))
......@@ -217,6 +221,26 @@ def predict(incoming_data: InputData):
'CompleteArguments': complete_arguments,
'FallArguments': fall_arguments,
}
def validate_data(incoming_data: InputData):
data = incoming_data.dict()
if not data['Gender'] in [0, 1]:
raise HTTPException(status_code=400, detail="Invalid gender, check input")
if not data['BirthYear'] in range(1900, 1999):
raise HTTPException(status_code=400, detail="Invalid birth year, check input")
if not data['LoanPeriod'] >= 0:
raise HTTPException(status_code=400, detail="Invalid loan period, check input")
if data['Ats'] != '':
incoming_ats = [x.strip(' ') for x in data['Ats'].split(",")]
ats_converter = {'ats_id': str}
ats_df = read_dataframe('ats.csv', converters=ats_converter)
if all(x in list(ats_df['ats_id']) for x in incoming_ats) != True:
raise HTTPException(status_code=400, detail="Ats not found, check ats list")
else:
data['Ats'] = '0'
return data
def add_embeddings(df: pd.DataFrame, case: str, ats_resolution: int) -> pd.DataFrame:
for i in range(1, ats_resolution+1):
......@@ -226,6 +250,13 @@ def add_embeddings(df: pd.DataFrame, case: str, ats_resolution: int) -> pd.DataF
df[column] = pd.to_numeric(df[column])
return df
def add_label_encoding(df: pd.DataFrame, encoders, ats_resolution: int) -> pd.DataFrame:
ats_cols = [f"{i}Ats" for i in range(1, ats_resolution+1)]
for col_name in ats_cols:
le = encoders[col_name]
df.loc[:, col_name] = le.transform(df.loc[:, col_name].astype(str))
return df
def generate_arguments(df: pd.DataFrame, ats_resolution: int, case: str, prob: float):
arguments = list()
......@@ -288,15 +319,10 @@ def prepare_data(data: dict, ats_resolution: int) -> pd.DataFrame:
df['BirthYear'] = df['BirthYear'].apply(lambda x: int(str(x)[2:]))
cols_ats = [str(i)+'Ats' for i in range(1, ats_resolution+1)]
header_list = ['Gender_Male', 'Gender_Female', 'BirthYear',
'Cluster', 'LoanPeriod', 'NumberAts'] + cols_ats
header_list = ['Gender', 'BirthYear', 'LoanPeriod', 'NumberAts'] + cols_ats
df = df.reindex(columns=header_list)
df = df.fillna('0')
citizen_ats = df.filter(regex='((\d+)[Ats])\w+', axis=1).values
df['Cluster'] = get_cluster(citizen_ats)
df['Cluster'] = pd.to_numeric(df['Cluster'])
return df
def get_shap_values(model, X_train=None, X_test=None) -> List:
......@@ -332,6 +358,12 @@ def read_embedding(filename: str):
embedding_dict = {rows[0]:rows[1] for rows in reader}
return embedding_dict
def read_pickle(filename: str) -> any:
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(f'{dir_path}/models/{filename}', 'rb') as f:
data = pickle.load(f)
return data
def read_joblib(filename: str) -> any:
dir_path = os.path.dirname(os.path.realpath(__file__))
return joblib.load(f'{dir_path}/models/{filename}')
......
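The endpoint above loads two new artifacts, models/alarm_rsf.joblib and models/alarm_labels.pkl, through read_joblib and the new read_pickle helper. Their training script is not part of this diff, so the following is only a sketch of how they could be produced: per-column LabelEncoders mirroring add_label_encoding, and a scikit-survival RandomSurvivalForest whose event_times_ and predict_survival_function the API returns. The dataset name, column names and hyperparameters below are assumptions.

import pickle
import joblib
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

df = pd.read_csv("alarm.csv")  # assumed processed alarm dataset
ats_resolution = 10
ats_cols = [f"{i}Ats" for i in range(1, ats_resolution + 1)]

# Fit one LabelEncoder per Ats column; the API's add_label_encoding applies
# the same encoders with transform() at prediction time.
encoders = {}
for col in ats_cols:
    le = LabelEncoder()
    df[col] = le.fit_transform(df[col].astype(str))
    encoders[col] = le

# "Status"/"Days" are assumed names for the event indicator and time-to-event.
y = Surv.from_arrays(event=df["Status"].astype(bool), time=df["Days"])
X = df.drop(columns=["Status", "Days"])  # remaining columns assumed numeric

model = RandomSurvivalForest(n_estimators=100, random_state=0)
model.fit(X, y)

joblib.dump(model, "models/alarm_rsf.joblib")
with open("models/alarm_labels.pkl", "wb") as f:
    pickle.dump(encoders, f)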
......@@ -20,7 +20,6 @@ def main():
training_cancelled = cleaner2020.clean_training_cancelled(tc, patient_data)
assistive_aids = cleaner2020.clean_assistive_aids(ats, ic)
file_writer.write_pickle(patient_data, pt.INTERIM_DATA_DIR, 'pd.pkl')
file_writer.write_pickle(screening_content, pt.INTERIM_DATA_DIR, 'sc.pkl')
file_writer.write_pickle(status_set, pt.INTERIM_DATA_DIR, 'ss.pkl')
file_writer.write_pickle(training_done, pt.INTERIM_DATA_DIR, 'td.pkl')
......
......@@ -3,23 +3,23 @@ import paths as pt
from tools import file_reader, file_writer
from tools import preprocessor
from utility import embedder
from utility.settings import load_settings
import pandas as pd
import numpy as np
from pathlib import Path
import yaml
from utility.settings import load_settings
def main():
for label_name in ["Complete", "Compliance", "Alarm", "Fall"]:
settings = load_settings(f'{label_name.lower()}_emb.yaml')
for label_name in ["Complete", "Compliance", "Fall"]:
data_settings = load_settings('data.yaml')
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
ats = {str(i)+'Ats':str for i in range(1, data_settings['ats_resolution']+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
f'{label_name.lower()}.csv',
converters=ats)
cols_ats = [str(i)+'Ats' for i in range(1, settings['ats_resolution']+1)]
unique_ats = [df[f'{i}Ats'].unique() for i in range(1, settings['ats_resolution']+1)]
cols_ats = [str(i)+'Ats' for i in range(1, data_settings['ats_resolution']+1)]
unique_ats = [df[f'{i}Ats'].unique() for i in range(1, data_settings['ats_resolution']+1)]
unique_ats = list(set(np.concatenate(unique_ats)))
df_ats = preprocessor.extract_cat_count(df, unique_ats, cols_ats, 'Ats_')
......
#!/usr/bin/env python
from tools import file_reader, file_writer
from tools import preprocessor, neural_embedder
from utility.settings import load_settings
import pandas as pd
import numpy as np
import paths as pt
......@@ -15,13 +16,10 @@ USE_GROUPING = False
ENABLE_EMB_VIZ = False
def main(ats_resolution: int = None):
for label_name in ["Complete", "Compliance", "Alarm", "Fall"]:
with open(Path.joinpath(pt.CONFIGS_DIR,
f'{label_name.lower()}_emb.yaml'), 'r') as stream:
settings = yaml.safe_load(stream)
for label_name in ["Complete", "Compliance", "Fall"]:
data_settings = load_settings('data.yaml')
if ats_resolution == None:
ats_resolution = settings['ats_resolution']
ats_resolution = data_settings['ats_resolution']
ats = {str(i)+'Ats':str for i in range(1, ats_resolution+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR,
......@@ -33,22 +31,18 @@ def main(ats_resolution: int = None):
df_to_enc = df.iloc[:,n_numerical_cols:]
ats_cols = [str(i)+'Ats' for i in range(1, ats_resolution+1)]
df = df.drop(ats_cols, axis=1)
# Load embedded config
with open(Path.joinpath(pt.CONFIGS_DIR,
f"{label_name.lower()}_emb.yaml"), 'r') as stream:
emb_cfg = yaml.safe_load(stream)
# Encode dataframe given params
model_path = Path.joinpath(pt.ROOT_DIR, emb_cfg['model_path'])
target_settings = load_settings(f'{label_name.lower()}_emb.yaml')
model_path = Path.joinpath(pt.ROOT_DIR, target_settings['model_path'])
df_enc = encode_dataframe(df=df_to_enc,
target_name=emb_cfg['target_name'],
batch_size=emb_cfg['batch_size'],
train_ratio=emb_cfg['train_ratio'],
epochs=emb_cfg['num_epochs'],
optimizer=emb_cfg['optimizer'],
network_layers=emb_cfg['network_layers'],
verbose=emb_cfg['verbose'],
target_name=target_settings['target_name'],
batch_size=target_settings['batch_size'],
train_ratio=target_settings['train_ratio'],
epochs=target_settings['num_epochs'],
optimizer=target_settings['optimizer'],
network_layers=target_settings['network_layers'],
verbose=target_settings['verbose'],
model_path=model_path)
df_rand = pd.DataFrame(np.random.rand(len(df),1), columns=['Rand']) # add random var
......
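The per-target embedding settings ({label}_emb.yaml) are now read through load_settings rather than an inline yaml.safe_load. Their exact contents are not shown in this diff, but from the keys accessed above the parsed dict presumably looks roughly like this (all values are placeholders):

target_settings = {
    'target_name': 'Complete',
    'batch_size': 32,
    'train_ratio': 0.8,
    'num_epochs': 10,
    'optimizer': 'Adam',
    'network_layers': [128],
    'verbose': False,
    'model_path': 'models/complete_emb',
    'features': ['Gender', 'BirthYear', 'LoanPeriod', 'NumberAts'],  # also read later in this diff
}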
......@@ -15,15 +15,14 @@ def main(ats_resolution: int = None):
converters={'CitizenId': str})
data_settings = load_settings('data.yaml')
if ats_resolution == None:
ats_resolution = data_settings['ats_resolution']
df = screenings.copy()
df['Cluster'] = clusters['Cluster']
accum_screenings = labeler.accumulate_screenings(df, data_settings)
for label_name in ['Complete', 'Compliance', 'Alarm', 'Fall']:
for label_name in ['Complete', 'Compliance', 'Fall']:
target_settings = load_settings(f'{label_name.lower()}_emb.yaml')
if ats_resolution == None:
ats_resolution = target_settings['ats_resolution']
features = target_settings['features']
# Encode target label
......@@ -31,10 +30,6 @@ def main(ats_resolution: int = None):
df = labeler.make_complete_label(accum_screenings)
elif label_name == 'Compliance':
df = labeler.make_compliance_label(accum_screenings)
elif label_name == 'Alarm':
df = labeler.make_alarm_label(accum_screenings)
df = df.replace({'Ats': {'22271812': '0', '22271813':'0',
'22271814': '0', '22271816': '0'}}, regex=True)
else:
df = labeler.make_fall_label(accum_screenings)
......
......@@ -11,7 +11,7 @@ from utility.settings import load_settings
def main():
for label_name in ["Complete", "Compliance", "Alarm", "Fall"]:
settings = load_settings(f'{label_name.lower()}_emb.yaml')
settings = load_settings("data.yaml")
ats = {str(i)+'Ats':str for i in range(1, settings['ats_resolution']+1)}
df = file_reader.read_csv(pt.PROCESSED_DATA_DIR, f'{label_name.lower()}.csv', converters=ats)
......@@ -21,7 +21,6 @@ def main():
df = pd.concat([df.drop(ats_cols + [label_name], axis=1),
df_enc, df[[label_name]]], axis=1)
# Save dataframe
file_writer.write_csv(df, pt.PROCESSED_DATA_DIR, f'{label_name.lower()}_ohe.csv')
if __name__ == "__main__":
......
......@@ -5,6 +5,7 @@ import numpy as np
import pandas as pd
import paths as pt
from tools import file_reader, file_writer, inputter
from utility.settings import load_settings
from utility import data_dto, dataset
from pandas.tseries.offsets import DateOffset
......@@ -14,9 +15,7 @@ def main():
td = file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'td.pkl')
tc = file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'tc.pkl')
ats = file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'ats.pkl')
with open(Path.joinpath(pt.CONFIGS_DIR, "data.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
settings = load_settings("data.yaml")
data = data_dto.Data(sc, ss, td, tc, ats)
screenings = get_screenings(data, settings)
......
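utility.settings.load_settings replaces the inline open/yaml.safe_load pattern removed above. Its implementation is not part of this diff; based on the code it replaces, a sketch might look like:

from pathlib import Path
import yaml
import paths as pt

def load_settings(filename: str) -> dict:
    # Assumed implementation: load a YAML file from the configs directory,
    # mirroring the open/yaml.safe_load blocks this commit removes.
    with open(Path.joinpath(pt.CONFIGS_DIR, filename), 'r') as stream:
        return yaml.safe_load(stream)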
#!/usr/bin/env python
from numpy.lib import utils
from tools import file_reader, file_writer, preprocessor
from utility.settings import load_settings
from pathlib import Path
import pandas as pd
import numpy as np
......@@ -8,12 +10,10 @@ import yaml
def main():
df = file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'ats.pkl')
settings = load_settings("data.yaml")
with open(Path.joinpath(pt.CONFIGS_DIR, "data.yaml"), 'r') as stream:
settings = yaml.safe_load(stream)
ats_delimiter = settings['ats_delimiter']
df['DevISOClass'] = df['DevISOClass'].apply(lambda x: x[:ats_delimiter]) # limit ats class
ats_iso_length = settings['ats_iso_length']
df['DevISOClass'] = df['DevISOClass'].apply(lambda x: x[:ats_iso_length]) # limit ats iso length
df = df[['CitizenId', 'BirthYear', 'Gender', 'LendDate', 'ReturnDate', 'DevISOClass']]
df = df.fillna(df.LendDate.max()) # replace invalid return dates with latest obs lend date
df = df.loc[df['ReturnDate'] >= df['LendDate']] # return date must be the same as or later than lend date
......@@ -24,9 +24,8 @@ def main():
mask_first = ~df.duplicated(subset=subset_cols, keep='first')
mask_last = ~df.duplicated(subset=subset_cols, keep='last')
hu_first = df[mask_first].loc[:, subset_cols + ['LendDate']]
hu_last = df[mask_last].loc[:, ['CitizenId', 'BirthYear',
'Gender', 'DevISOClass',
'ReturnDate']]
hu_last = df[mask_last].loc[:, ['CitizenId', 'BirthYear', 'Gender',
'DevISOClass', 'ReturnDate']]
merged = pd.merge(hu_first, hu_last, on=subset_cols)[['CitizenId', 'BirthYear',
'Gender', 'DevISOClass',
'LendDate', 'ReturnDate']]
......
#!/usr/bin/env python
from pathlib import Path
from tools import file_reader
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import numpy as np
import pickle
from tools import file_reader, file_writer
import datetime
import paths as pt
LOAD_PREPROCESSED_DATA = False
PREPROCESSED_SEQUENCE_NAME = 10
ZERO_PAD_SEQUENCES = True
RUN_SMALLER_SAMPLE_SIZE = True
SEQ_LENGTH = 10
OVERWRITE_LOCAL_PREPROCESSED_SEQUENCES = False
# Get data
def load_data_and_mapping():
raw_df = file_reader.read_pickle(pt.INTERIM_DATA_DIR, 'ats.pkl').reset_index(drop=True)
mapping = file_reader.read_csv(pt.REFERENCES_DIR, 'ats.csv', converters={'ats_id': str})
return raw_df, mapping
# Shorten iso class
def shorten_iso(raw_df, mapping):
df = raw_df.copy()
df['DevISOClass'] = df['DevISOClass'].apply(lambda x: x[:6])
df = df.dropna(subset=['CitizenId'])
mapping_dict = dict(mapping.values)
df = df.replace(to_replace=mapping_dict)
df.sort_values(by='LendDate', inplace=True)
digit_values = df['DevISOClass'].str.isdigit()
df.drop(df[digit_values].index, inplace=True)
return df
# Remove citizens with fewer records than needed to form a sequence
def remove_single_records(df, seq_length):
counts = df.CitizenId.value_counts()
df = df[~df['CitizenId'].isin(counts[counts < seq_length + 1].index)]
return df
# Create list of sequence
def create_sequence_list(df, seq_length, column_name='DevISOClass'):
df = remove_single_records(df, seq_length)
X = []
y = []
print(f'shape: {df.shape}')
for citizenId in df.CitizenId.unique():
temp = df.loc[df.CitizenId==citizenId]
seq = tf.keras.preprocessing.sequence.TimeseriesGenerator(temp[column_name].values,
temp[column_name].values,
length=seq_length,
batch_size=1,
shuffle=False)
for x in seq:
    X.append(x[0][0].tolist())
    y.append(x[1])
return X, y
# Create list of sequence with zero-padding
def create_sequence_list_pad(df, seq_length, column_name='DevISOClass'):
X = []
y = []
for temp_seq in range(seq_length):
temp_seq = temp_seq + 1
temp_df = remove_single_records(df, temp_seq)
for i, citizenId in enumerate(temp_df.CitizenId.unique()):
# Skip the first citizen so that no sequence is created for the very first sample
if i == 0:
continue
temp = temp_df.loc[temp_df.CitizenId==citizenId]
seq = tf.keras.preprocessing.sequence.TimeseriesGenerator(temp[column_name].values,
temp[column_name].values,
length=temp_seq,
batch_size=1,
shuffle=False)
for x in seq:
to_append = np.pad(x[0][0], (seq_length-temp_seq, 0), 'constant', constant_values=(0))
X.append(to_append)
y.append(x[1])
return X, y
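# For example (illustrative values): with SEQ_LENGTH = 10 and temp_seq = 3, a window
# [12, 7, 31] with target 5 is stored as X = [0, 0, 0, 0, 0, 0, 0, 12, 7, 31] and y = 5,
# so shorter histories are left-padded with zeros and every sample has the same length.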
def make_strings_categorical(string_list, mapping_to_numbers={}):
y_true = np.zeros((len(string_list)))
for i, raw_label in enumerate(string_list):
if raw_label not in mapping_to_numbers:
mapping_to_numbers[raw_label] = len(mapping_to_numbers)
#y_true[i] = mapping_to_numbers[raw_label] +1 ??
mapping_to_numbers['ZeroPadding'] = 0
return mapping_to_numbers
def convert_str_to_int(word, mapping):
return mapping[word]
def load_pickle(seq_length):
X = file_reader.read_pickle(pt.INTERIM_DATA_DIR, f'X_lstm_{seq_length}.pkl')
y = file_reader.read_pickle(pt.INTERIM_DATA_DIR, f'y_lstm_{seq_length}.pkl')
return X, y
def save_pickle(data, file_name):
file_writer.write_pickle(data, pt.INTERIM_DATA_DIR, file_name)
def topKacc_5(Y_true, Y_pred):
return tf.keras.metrics.top_k_categorical_accuracy(Y_true, Y_pred, k=5)
def topKacc_4(Y_true, Y_pred):
return tf.keras.metrics.top_k_categorical_accuracy(Y_true, Y_pred, k=4)
def topKacc_3(Y_true, Y_pred):
return tf.keras.metrics.top_k_categorical_accuracy(Y_true, Y_pred, k=3)
def topKacc_2(Y_true, Y_pred):
return tf.keras.metrics.top_k_categorical_accuracy(Y_true, Y_pred, k=2)
def main():
raw_df, mapping = load_data_and_mapping()
if RUN_SMALLER_SAMPLE_SIZE:
raw_df = raw_df.head(10000)
counts = raw_df.DevISOClass.value_counts()
# raw_df = raw_df.loc[raw_df.DevISOClass.isin(counts.index[counts > seq_length])]
df = shorten_iso(raw_df, mapping)
string_to_int_mapping = make_strings_categorical(df.DevISOClass.values)
vocabulary_size = len(string_to_int_mapping)+1
df['DevISOClassCategorical'] = np.vectorize(convert_str_to_int)(df['DevISOClass'], string_to_int_mapping)
# Create sequences
# Load already processed sequences to save time
if LOAD_PREPROCESSED_DATA:
X, y = load_pickle(PREPROCESSED_SEQUENCE_NAME)
elif ZERO_PAD_SEQUENCES:
X, y = create_sequence_list_pad(df[['DevISOClassCategorical', 'CitizenId']],
SEQ_LENGTH, column_name='DevISOClassCategorical')
else:
X, y = create_sequence_list(df, SEQ_LENGTH, column_name='DevISOClassCategorical')
# Store sequences locally in pickle-file
if OVERWRITE_LOCAL_PREPROCESSED_SEQUENCES:
save_pickle(X, file_name=f'X_{SEQ_LENGTH}.pkl')
save_pickle(y, file_name=f'y_{SEQ_LENGTH}.pkl')
# Split data into train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
train_targets = to_categorical(y_train, num_classes=vocabulary_size)