Commit ef58f4a4 authored by thecml's avatar thecml
Browse files

updated models for cv

parent c3785199
Pipeline #98685 passed with stage
in 4 minutes and 59 seconds
#!/usr/bin/env python
import numpy as np
import paths as pt
from typing import List
from tools import file_writer, preprocessor, data_loader
from tools.classifiers import KnnClassifier, SvmClassifier, LrClassifier
from tools.classifiers import XgbClassifier, RfClassifier, MlpClassifier
from sklearn.model_selection import cross_validate
from tools import data_loader
from sklearn.preprocessing import LabelEncoder
from utility.settings import load_settings
from utility.metrics import compute_mean, compute_std
from sksurv.ensemble import RandomSurvivalForest
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.model_selection import KFold
from sksurv.metrics import concordance_index_censored
import pandas as pd
def main():
data_settings = load_settings(pt.CONFIGS_DIR, "data.yaml")
......
......@@ -14,15 +14,12 @@ def main():
td = loader.load_training_done(pt.PATHS_2021[2],
pt.PATHS_2021[4],
pt.RAW_DATA_DIR_2021)
sc = loader.load_screening_content(pt.PATHS_2021[2],
pt.PATHS_2021[4],
pt.RAW_DATA_DIR_2021)
tc = loader.load_training_cancelled(pt.PATHS_2021[2],
pt.PATHS_2021[4],
pt.RAW_DATA_DIR_2021)
ss = loader.load_status_set(pt.PATHS_2021[2],
pt.PATHS_2021[4],
pt.RAW_DATA_DIR_2021)
......
......@@ -62,24 +62,19 @@ class SvmClassifier(BaseClassifer):
"""Support vector machine classifier."""
def make_model(self):
    """Build the unfitted support vector classifier.

    The estimator is created with a fixed random seed, balanced class
    weights, and probability estimates enabled.
    """
    svc_options = {
        "random_state": 0,
        "class_weight": "balanced",
        "probability": True,
    }
    return SVC(**svc_options)
class LrClassifier(BaseClassifer):
    """Logistic regression classifier."""

    def make_model(self):
        """Build the unfitted logistic regression estimator.

        The estimator is created with a 1000-iteration cap, balanced
        class weights, and a fixed random seed.
        """
        lr_options = {
            "max_iter": 1000,
            "class_weight": "balanced",
            "random_state": 0,
        }
        return LogisticRegression(**lr_options)
class XgbClassifier(BaseClassifer):
"""XGBoost classifier."""
def make_model(self):
neg, pos = np.bincount(self.y)
scale_pos_weight = neg / pos
params = {"n_estimators": 400,
"learning_rate": 0.1,
"scale_pos_weight": scale_pos_weight,
"objective": "binary:logistic",
"random_state": 0,
"use_label_encoder": False,
......@@ -90,7 +85,6 @@ class RfClassifier(BaseClassifer):
"""Random Forest classifier."""
def make_model(self):
    """Build the unfitted random forest classifier.

    The estimator is created with 800 trees, balanced class weights,
    and a fixed random seed.
    """
    forest_options = {
        "n_estimators": 800,
        "class_weight": "balanced",
        "random_state": 0,
    }
    return RandomForestClassifier(**forest_options)
class MlpClassifier(BaseClassifer):
......@@ -118,7 +112,5 @@ class MlpClassifier(BaseClassifer):
optimizer="Adam",
metrics=metrics)
return model
neg, pos = np.bincount(self.y)
class_weight = preprocessor.get_class_weight(neg, pos)
return KerasClassifier(make_keras_model, epochs=20, batch_size=32,
class_weight=class_weight, verbose=False)
return KerasClassifier(make_keras_model, epochs=20,
batch_size=64, verbose=False)
......@@ -115,7 +115,7 @@ def write_cv_plot(means: List, stds: List, metric: str,
fig.suptitle(f"{os.path.splitext(title)[0]} {subtitle}")
plt.setp(axs[-1, :], xlabel='Seed')
plt.setp(axs[:, 0], ylabel=metric)
plt.savefig(outfile, dpi=300, bbox_inches = "tight")
plt.savefig(outfile, dpi=300, bbox_inches="tight")
def write_roc_curve(y_true: np.ndarray, results: List,
title: str, subtitle: str, outfile: BytesIO):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment