Skip to content
Snippets Groups Projects
Commit 1e4a69f3 authored by Malthe Kjær Bisbo's avatar Malthe Kjær Bisbo
Browse files

added scale_reg and template_structure to gpr_state for GOFEE restart

parent f80a044e
No related branches found
No related tags found
No related merge requests found
......@@ -876,7 +876,8 @@ class GOFEE():
kernel=gpr_state[1],
prior=gpr_state[2],
n_restarts_optimizer=gpr_state[3],
template_structure=training_structures[0])
scale_reg= gpr_state[4],
template_structure=gpr_state[5])
self.save_structures(training_structures)
if self.population_method != 'clustering':
self.population.gpr = self.gpr
......
......@@ -333,6 +333,7 @@ class GPR():
self.prior = prior
self.n_restarts_optimizer = n_restarts_optimizer
self.template_structure = template_structure
self.scale_reg = scale_reg
self.memory = gpr_memory(self.descriptor, self.prior, **kwargs)
......@@ -693,7 +694,8 @@ class GPR():
deepcopy(self.kernel),
deepcopy(self.prior),
deepcopy(self.n_restarts_optimizer),
deepcopy(self.scale_reg)]
deepcopy(self.scale_reg),
deepcopy(self.template_structure)]
def get_local_model(self, ref, Nmax, Nmax_force=None, Nforces=None, save_folder=None, n_restarts_optimizer=1):
mem_new = self.memory.get_trimmed_memory(ref, Nmax, Nmax_force, Nforces, save_folder)
......
No preview for this file type
......@@ -130,7 +130,7 @@ def logrank(n1,d1,t1,n2,d2,t2):
return st.norm.sf(abs(Z))*2
# This is the real function of interest
def survival_stats(times,events,alpha=0.95,sigma=5000,show_plot=True,legend_outside=False,save=True,get_hazard=True,labels=[], save_dir='stats'):
def survival_stats(times,events,alpha=0.95,sigma=5000,show_plot=True,legend_outside=False,save=True,get_hazard=True,labels=[], colors=None, linestyles=None, save_dir='stats'):
"""This function calculateds a number of statistics that may beof interest
inputs:
......@@ -252,14 +252,20 @@ def survival_stats(times,events,alpha=0.95,sigma=5000,show_plot=True,legend_outs
if show_plot:# make plots
labels += range(len(labels),n_inputs)
f, ax = subplots(1,1, figsize=(7+3*legend_outside,5))
colors = ['b','r','g','y','c','m']
if colors is None:
colors = ['b','r','g','y','c','m']
max_time = 0
for i in range(n_inputs):
color_i = colors[i%len(colors)]
if linestyles is None:
linestyle_i = None
else:
linestyle_i = linestyles[i]
try:
ax.fill_between(CDF[i][0], CDF[i][2], CDF[i][3], step='post', facecolor=colors[i%len(colors)], alpha=0.1)
ax.fill_between(CDF[i][0], CDF[i][2], CDF[i][3], step='post', facecolor=color_i, alpha=0.1)
except:
ax.fill_between(CDF[i][0], CDF[i][2], CDF[i][3], facecolor=colors[i%len(colors)], alpha=0.1)
ax.step(CDF[i][0], CDF[i][1],where='post',c=colors[i%len(colors)], label=labels[i])
ax.fill_between(CDF[i][0], CDF[i][2], CDF[i][3], facecolor=color_i, alpha=0.1)
ax.step(CDF[i][0], CDF[i][1],where='post',c=color_i, linestyle=linestyle_i, label=labels[i])
#ax.plot(CDF[i][0][censoring[i]],CDF[i][1][censoring[i]],marker='+',c='k')
if CDF[i][0][-1] > max_time:
max_time = KM[i][0][-1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment