Commit 15e22b17 authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

Delete plot_functions.py

parent ba800271
def plot_function(self,Models,y,logwriter,logfile,strategy,average_dis,distribution,l_list,ite,validate_interval,prefix):
image = tf.cast(image,self.precision)
features,feat = strategy.run(self.predict_encoder,args=(image,))
feat.numpy().astype(np.float32)
features.numpy().astype(np.float32)
classes,invers_counts = np.unique(features,axis=0,return_inverse=True)
self.gui_plot(int(ite/validate_interval),feat,np.asarray(l_list)[1:],np.asarray(average_dis),distribution,colors=y,prefix=prefix) # plot the data with label colors
if bool(self.verbose) > 0:
c = np.arange(classes.shape[0])[invers_counts]
acc = np.round(accuracy_score(y, c), 5) # accuracy
nmi = np.round(normalized_mutual_info_score(y, c), 5) # normalized mutual information
ari = np.round(adjusted_rand_score(y, c), 5) #
print('Iter %d: of %d acc=%.5f, nmi=%.5f, loss=%.5f, dis_avg=%.5f' % (ite,self.steps, acc, nmi, l_list[-1], average_dis[-1])) # print into console , specific statistics
logdict = dict(iter=ite, acc=acc, nmi=nmi, loss=l_list[-1], dis_avg=average_dis[-1]) # log it into logfile.
logwriter.writerow(logdict) # write row
logfile.flush()# flush into file to avoid memory consumption
else:
print('Iter %d: of %d ,loss=%.5f, dis_avg=%.5f' % (ite,self.steps, l_list[-1], average_dis[-1])) # print information
logdict = dict(iter=ite, loss=l_list[-1], dis_avg=average_dis[-1]) # log it into logfile.
logwriter.writerow(logdict)
def plot_gan(self,Models,logwriter,logfile,strategy,d_list,l_list,ite,validate_interval,prefix,generator):
image = tf.cast(image,self.precision)
models = strategy.run(generator,args=(Models,))
view_models = self.project_a_model(models)
self.gui_plot(int(ite/validate_interval),view_models,np.asarray(d_list)[1:],np.asarray(l_list)[1:],models,prefix=prefix) # plot the data with label colors
print('Iter %d: of %d loss_generator=%.5f, loss_descriminator=%.5f' % (ite,self.steps)) # print into console , specific statistics
logdict = dict(iter=ite, loss=l_list[-1]) # log it into logfile.
logwriter.writerow(logdict) # write row
logfile.flush()# flush into file to avoid memory consumption
def gui_plot(self,image_real,loss_generator,loss_discriminator,model_similarity,prefix=None):
fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(nrows=2,ncols=2)
s = image_real.shape[1]
top_r = np.reshape(image_real,(2*s,2*s))
top_f = np.reshape(image_fake,(2*s,2*s))
pca_all = TSNE(2).fit_transform(features)
plt.tight_layout()
ax1.imshow(top_r)
ax1.set_title("Projections from Orthogonal Vectors")
ax2.plot(np.arange(loss_generator.shape[0])*self.validate_interval, loss_generator, 'go--', linewidth=1, markersize=8)
ax3.plot(np.arange(loss_discriminator.shape[0])*self.validate_interval, loss_discriminator, 'ro--', linewidth=1, markersize=8)
ax2.set_ylabel('loss', fontsize=10)
ax2.set_title("Generative Loss")
ax3.imshow(top_f)
ax3.set_title("Discriminative Loss")
ax4.imshow(np.dot(model_similarity,model_similarity.T))
ax4.set_ylabel('Cross correlation', fontsize=10)
ax4.set_title("Model similarity")
plt.savefig(join(self.results,'%s_output.png' %prefix))
plt.clf()
plt.close('all')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment