Commit 9ec83b9f authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

corrected log file

parent ecfe932b
import tensorflow as tf
import numpy as np
import umap
from os.path import join
from trainer_sortem import Trainer
from mrc_loader_sortem import mrc_loader
class GAN_NERF():
......@@ -10,7 +10,7 @@ class GAN_NERF():
self.args = args
dic = {32:1,64:2,128:3,256:4,512:5}
self.predict_steps = int(np.ceil(args['number_particles']/(args['num_gpus']*args['batch_size'])))
self.dic = dic
l = np.asarray([32,64,128,256,512])
self.args['resize'] = l[np.argmin(np.abs(l-self.args['size']))]
l_list = []
......@@ -52,16 +52,19 @@ class GAN_NERF():
predict_generator = mrc_loader(args).pred_generate()
output_generator = mrc_loader(args).pred_generate()
if args['num_gpus'] > 1:
strategy = tf.distribute.MirroredStrategy(devices= gpu_list )
self.generator = strategy.experimental_distribute_dataset( generator)
self.generator_pred = strategy.experimental_distribute_dataset( predict_generator )
self.output_generator = strategy.experimental_distribute_dataset( output_generator )
else:
strategy = tf.distribute.OneDeviceStrategy(device=gpu_list[0])
self.generator = strategy.experimental_distribute_dataset( generator )
self.generator_pred = strategy.experimental_distribute_dataset( predict_generator )
self.output_generator = strategy.experimental_distribute_dataset( output_generator )
args['strategy'] = strategy
self.trainer = Trainer(args)
......@@ -72,7 +75,7 @@ class GAN_NERF():
print('Begin training: ', '-' * 60)
current_step = self.trainer.step_variable
gen = iter(self.generator)
pred = iter(self.generator_pred)
for i in range(int(current_step)):
# this starts the data recording at where it left off
# this is to prevent when continuation of training the model does not use the same data
......@@ -101,27 +104,52 @@ class GAN_NERF():
self.trainer.distributed_training_step(params)
if (i % self.args['record']) == 0:
if self.args['num_gpus'] == 1:
self.single_device_model_maker.model_maker()
features = []
current_shape = params['image'].numpy().shape[1]
for kk in range(int(np.ceil(self.args['umap_t_size']/self.args['batch_size']))):
data = next(pred)
features.append(self.args['strategy'].run(self.trainer.get_features,
args=(data,params['alpha'],self.trainer.Encoder[int(params['index'])],current_shape)))
self.trainer.write_summaries(features)
else:
self.multi_device_model_maker.model_maker()
self.trainer.write_summaries()
data = next(pred)
features = []
for kk in range(int(np.ceil(self.args['umap_t_size']/(self.num_gpus*self.args['batch_size'])))):
features.append(self.args['strategy'].run(self.trainer.get_features,
args=(data,params['alpha'],self.trainer.Encoder[int(params['index'])],current_shape)).reduce())
self.trainer.write_summaries(features)
if (i % self.args['save_model']) == 0:
self.trainer.save_checkpoint()
self.trainer.save_best_model()
def over_cluster(self):
self.trainer.load_best_model()
trainer.sparse_water_sheed_algorithm()
def predict(self):
pred = iter(self.output_generator)
output_vectors = []
if not isfile(join(self.args['results'],'final_featur_vectors.npy')):
if self.args['num_gpus'] > 1:
for kk in range(int(np.ceil(self.args['depth']/self.args['num_gpus']*self.args['batch_size']))):
data = next(pred)
output_vectors.append(self.args['strategy'].run(self.trainer.get_features,args=(data,params['alpha'],self.trainer.Encoder[self.dic[self.args['resize']]],current_shape)).reduce())
else:
for kk in range(int(np.ceil(self.args['depth']/self.args['batch_size']))):
data = next(pred)
output_vectors.append(self.args['strategy'].run(self.trainer.get_features,args=(data,params['alpha'],self.trainer.Encoder[self.dic[self.args['resize']]],current_shape)))
np.save(join(self.args['results'],'final_featur_vectors.npy'))
labels,umap_output,collect_centers = pred_umap(args,feature_vector)
if not isfile(join(self.args['results'],'final_labels.npy')):
np.save(join(self.args['results'],'final_labels.npy'))
np.save(join(self.args['results'],'final_umap_output.npy'))
np.save(join(self.args['results'],'final_collect_centers.npy'))
self.trainer.load_best_model()
#self.trainer.model_maker()
bools = isfile(join(self.args['results'],'over_cluster.npy'))
if bools:
labels = np.load(join(self.args['results'],'over_cluster.npy'))
clusters = []
for i in range(self.predict_steps):
image = next(self.generator_pred )
......
......@@ -33,7 +33,7 @@ class ED_Maker(Model):
self.final_layer = final_layer
self.fromRGB.reverse()
self.start = self.img_size_to_layer[shape]
self.flatten = Flatten()
self.lays = layers
def call(self,input,alpha):
......@@ -51,7 +51,7 @@ class ED_Maker(Model):
x = layer(x)
out = tf.squeeze(self.final_layer(x))
out = self.flatten(self.final_layer(x))
......@@ -61,7 +61,8 @@ class ED_Maker(Model):
v = []
for i in self.lays[:self.start ]:
v+= i.trainable_variables
return v+self.fromRGB[self.start-1].trainable_variables+self.fromRGB[ int(self.start) ].trainable_variables
k = self.final_layer.trainable_variables
return v+k+self.fromRGB[self.start-1].trainable_variables+self.fromRGB[ int(self.start) ].trainable_variables
class ResidualConvBlock(Layer):
def __init__(self,inplanes, planes, kernel_size=3, stride=1, downsample=False, groups=1):
......@@ -167,7 +168,7 @@ class Ray_maker(Layer):
s = tf.shape(coordinates)[1]
powered = tf.expand_dims(tf.pow(tf.cast(2.0,dtype=coordinates.dtype),tf.cast(tf.range(self.init_L),dtype=coordinates.dtype)),axis=0)
out = tf.matmul(coordinates,powered,coordinates.dtype)
out = tf.matmul(coordinates,powered,transpose_a=True)
out = tf.reshape(out,[-1,s,self.init_L*3])
......@@ -212,9 +213,9 @@ class Ray_maker(Layer):
a,b = tf.split(z,2,axis=-1)
a = tf.squeeze(a)
b = tf.squeeze(b)
fourie = tf.complex(tf.cast(a,tf.float32),tf.cast(b,tf.float32))
return tf.squeeze(fourie)
return a,b
class Noise_Maker(Model):
......
......@@ -64,10 +64,8 @@ class mrc_loader:
image = tf.io.decode_raw(ins,tf.float32)
image = tf.reshape(image,[self.df_keys['size'],self.df_keys['size'],1])
image = tf.image.per_image_standardization(image)
minima = tf.reduce_min(image)
tmp = image+minima
image = tmp/tf.reduce_max(tmp)
return image
image = image-tf.reduce_min(image)
return image/tf.reduce_max(image)
data = tf.data.FixedLengthRecordDataset(self.df_keys['mrc_paths'],self.df_keys['bpr'],num_parallel_reads=self.df_keys['num_cpus'], header_bytes=1024).map(pred_image,self.df_keys['num_cpus']).batch(self.df_keys['batch_size']).repeat()
return data
......@@ -75,3 +73,25 @@ class mrc_loader:
# def generator_model(self,input_vectors,):
class Grid_Maker:
def __init__(self,kwargs,means):
self.size = kwargs['size']
x = tf.linspace(0.0,0.5,int(self.size/2)+1)
y = tf.linspace(-0.5,0.5,self.size)
z = tf.linspace(-0.5,0.5,self.size)
X,Y,Z = tf.meshgrid(x,y,z)
self.X = X
self.Y = Y
self.Z = Z
self.mean_size = means.shape[0]
def grid_generator(self):
x_slice = tf.from_tensor_slices(self.X)
y_slice = tf.from_tensor_slices(self.Y)
z_slice = tf.from_tensor_slices(self.Z)
return tf.data.Dataset.zip(x_slice,y_slice,z_slice).repeat(self.mean_size).batch(self.kwargs['m_batch_size'])
import tensorflow as tf
from tensorflow.keras import mixed_precision
from models_sortem import Ray_maker,ResidualConvBlock,AdapterBlock,ED_Maker,Noise_Maker,ResidualUpBlock,UpAdapterBlock
from utils_sortem import Discriminator_Loss,Generator_Loss,Encoder_Loss,DiversityLoss,Eval_CTF,Make_Grids,Poisson_Measure
from utils_sortem import gradient_penalty,Discriminator_Loss,Generator_Loss,Encoder_Loss,DiversityLoss,Eval_CTF,Make_Grids,Poisson_Measure,make_umap,Make_2D_Transform,get_parameters
from os.path import join
import matplotlib.pyplot as plt
import numpy as np
import umap
import mrcfile
from tensorflow.keras.layers import Conv2D
from pathlib import Path
from os.path import isfile
......@@ -19,39 +17,43 @@ class Trainer:
self.discriminator_loss = []
self.noise_loss_list = []
self.gen_loss = []
self.kl_loss = []
self.noise_loss = []
self.regulizor_loss = []
self.enc_loss = []
if isfile(join(args['results'],'discriminator_loss')):
self.discriminator_loss_list += np.load(join(args['results'],'discriminator_loss')).tolist()
else:
self.discriminator_loss_list = []
if isfile(join(args['results'],'gen_loss')):
self.gen_loss_list += np.load(join(args['results'],'gen_loss')).tolist()
self.gen_loss += np.load(join(args['results'],'gen_loss')).tolist()
else:
self.gen_loss_list = []
if isfile(join(args['results'],'noise_loss')):
self.noise_loss_list += np.load(join(args['results'],'noise_loss')).tolist()
else:
self.noise_loss_list = []
if isfile(join(args['results'],'regulizor_loss')):
self.regulizor_loss_list += np.load(join(args['results'],'regulizor_loss')).tolist()
if isfile(join(args['results'],'enc_loss')):
self.enc_loss_list += np.load(join(args['results'],'enc_loss')).tolist()
else:
self.regulizor_loss_list = []
self.enc_loss_list = []
self.args = args
if self.args['f16']: # convert models to mixed precision
if self.args['f16']: # require the mixed precision policy to be used.
self.dtype_enc = tf.float16
mixed_precision.set_global_policy('mixed_float16')
else:
self.dtype_enc = tf.float32
# initialize the output layers of the model
fin_encode = Conv2D(2*10+6,2,strides=2,padding='SAME')
fin_decode = Conv2D(1,2,strides=2,padding='SAME')
fin_noise = Conv2D(1,1,padding='SAME')
self.make_grids = Make_Grids()
self.poisson_noise = Poisson_Measure()
# the noise model with unknown signal/noise ratio
self.make_2d_transforms = [Make_2D_Transform(32),
Make_2D_Transform(64),
Make_2D_Transform(128),
Make_2D_Transform(256),
Make_2D_Transform(512),
]
self.noise = [
ResidualUpBlock(400, 400, downsample=True), # 16x16 -> 8x8
......@@ -65,7 +67,7 @@ class Trainer:
]
# the adapter when the model increases resolution
self.noise_adapter = [
UpAdapterBlock(400),
UpAdapterBlock(400),
......@@ -153,8 +155,8 @@ class Trainer:
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=512)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=512)]
self.Noise += [Noise_Maker(layers=self.noise,fromRGB=self.noise_adapter,final_layer=fin,shape=512)]
self.current_gen_loss = 0
self.previous_gen_loss = 10**6
self.gen_op = tf.keras.optimizers.Adam(learning_rate=args['lr_g']) # optimizer for GAN step # (learning_rate=self.lr_linear_decay_generator
self.decoder_op = tf.keras.optimizers.Adam(learning_rate=args['lr_d']) # optimizer for # learning_rate=self.lr_linear_decay_encoder
self.encoder_op = tf.keras.optimizers.Adam(learning_rate=args['lr_e'])
......@@ -171,7 +173,7 @@ class Trainer:
'Generator': Ray_maker(),
'Noise': self.Noise,
'gen_op': self.gen_op,
'decoder_op':self.decoder_op,
'disc_op':self.decoder_op,
'encoder_op':self.encoder_op,
}
self.step_variable = self.vd['steps'].numpy()
......@@ -183,11 +185,10 @@ class Trainer:
if not args['noise']:
self.noise_loss_list = self.noise_loss_list[:self.step_variable]
# the metrics
self.metrics_dictionary = {tf.keras.metrics.Mean('gen_loss', dtype=self.dtype_enc),
tf.keras.metrics.Mean('dis_loss', dtype=self.dtype_enc),
tf.keras.metrics.Mean('encoder_loss', dtype=self.dtype_enc),
tf.keras.metrics.Mean('noise_diversity_loss', dtype=self.dtype_enc)}
# the metrics to record the average of loss over time to the tensorboard session
self.metrics_dictionary = {'gen_loss':tf.keras.metrics.Mean('gen_loss', dtype=self.dtype_enc),
'dis_loss':tf.keras.metrics.Mean('dis_loss', dtype=self.dtype_enc),
'enc_loss':tf.keras.metrics.Mean('encoder_loss', dtype=self.dtype_enc)}
......@@ -195,19 +196,19 @@ class Trainer:
self.encoder_loss = Encoder_Loss(reduction=tf.keras.losses.Reduction.SUM)
self.generator_loss = Generator_Loss(reduction=tf.keras.losses.Reduction.SUM)
if not args['noise']:
self.max_steps = np.sum(self.args['steps'])
if args['noise']:
self.diversity_loss = DiversityLoss(reduction=tf.keras.losses.Reduction.SUM)
with self.args['strategy'].scope():
pass
#self.train_summary_writer = tf.summary.create_file_writer(self.args['tmp'])
#self.ckpt = tf.train.Checkpoint(**self.vd)
#self.manager = tf.train.CheckpointManager(self.ckpt ,self.args['model'], max_to_keep=3)
#self.best_current_model = tf.train.CheckpointManager(self.ckpt ,self.args['best_model'], max_to_keep=1)
#self.ckpt.restore(self.manager.latest_checkpoint)
self.train_summary_writer = tf.summary.create_file_writer(self.args['tmp'])
self.ckpt = tf.train.Checkpoint(**self.vd)
self.manager = tf.train.CheckpointManager(self.ckpt ,self.args['model'], max_to_keep=3)
......@@ -222,161 +223,150 @@ class Trainer:
if self.args['f16']:
gradients = optimizer.get_unscaled_gradients(gradients)
optimizer.apply_gradients(zip(gradients,variables))
def save_best_model(self):
if self.current_pixel_loss < self.previous_pixel_loss:
self.best_current_model.save()
self.previous_pixel_loss = self.current_pixel_loss
def save_checkpoint(self):
# save all weights from vd dictionary
self.manager.save()
def restore_checkpoint(self):
# restore from the last checkpoint. The checkpoint is the validation interval step
self.ckpt.restore(self.manager.latest_checkpoint)
def __record_loss__(self,loss,metrics_dic,labels):
# put the loss into the summary writer
def __record_loss__(self,loss,metrics_dic):
for i in metrics_dic.keys():
metrics_dic[i](loss[i])
@tf.function(jit_compile=True)
def __get_views__(self,means,labels):
high_res_1 = tf.linspace(-1.0,1.0,128)
high_res_2 = tf.linspace(-1.0,1.0,int(128/2)+1)
def __get_views__(self,means):
# create a top and size view grid by flipping the xyz, zyx axis.
high_res_1 = tf.linspace(-0.5,0.5,self.args['size'])
high_res_2 = tf.linspace(-0.5,0.5,int(self.args['size']/2)+1)
X,Y = tf.meshgrid(high_res_1,high_res_2)
X = tf.reshape(X,[-1])
Y = tf.reshape(Y,[-1])
Z = tf.zeros(shape=tf.shape(Y))
side_coordinates = tf.tile(tf.expand_dims(tf.stack([X,Y,Z],axis=1),axis=1),[1,tf.shape(means)[0],1])
top_coordinates = tf.tile(tf.stack([X,Z,Y],axis=1),[tf.shape(means)[1],1])
side_coordinates = tf.tile(tf.stack([X,Z,Y],axis=1),[tf.shape(means)[1],1])
view_x = []
image_top = self.vd['Generator'](means,top_coordinates)
image_side =self.vd['Generator'](means,side_coordinates)
image_top = self.vd['Generator'](self.args['size'],means,tf.roll(side_coordinates,1,axis=-1))
image_side =self.vd['Generator'](self.args['size'],means,side_coordinates)
return tf.signal.irfft2d(image_top), tf.signal.irfft2d(image_side)
@tf.function(jit_compile=True)
def __get_samples__(self,image):
parameters = self.vd['Encoder'](image)
feature_vector = parameters['mean']+parameters['variance']
return feature_vector
def hide_frame(self,ax):
# fode the output frames
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.set_axis_off()
def make_figure(self,pre,image,labels):
l = np.unique(labels)[1:]
fig, axs = plt.subplots(len(l), figsize=(self.args['frames'],len(l)+1))
stack = []
for index,i in enumerate(np.split(image,l.shape[0],axis=0)):
i = np.squeeze(i)
self.hide_frame(axs[index])
image_classes = np.concatenate(np.split(i,i.shape[0],axis=0),axis=-1)
image_classes = np.squeeze(image_classes)
axs[index].imshow(image_classes,cmap='gray')
axs[index].set_title('class %i' %index)
plt.savefig('%s_views_steps_%i' %(pre,int(self.vd['steps'])))
def full_model_reconstruction(self,mean,grid,shape):
slices = self.vd['Generator'](mean,grid,self.args['size'])
return slices
def write_summaries(self,feature_vectors):
means = self.make_umap(args,feature_vectors)
labels,mapped,means = make_umap(self.args,feature_vectors)
np.save(join(self.args['results'],'labels_step_%i.npy' %int(self.vd['steps'].numpy())),labels)
np.save(join(self.args['results'],'umap_coordinates_step_%i.npy' %int(self.vd['steps'].numpy())),mapped)
np.save(join(self.args['results'],'means_%i.npy' %int(self.vd['steps'].numpy())),means)
top_view,side_view = self.__get_views__(means)
self.make_figure('top',top_view,labels)
self.make_figure('side',side_view,labels)
new_dictionary = {}
fig, axs = plt.subplots(2, 4, figsize=(9, 3),share_x=True,share_y=True)
for keys, values in zip(self.metrics_dic.keys(),self.metrics_dic.values()):
step = int(self.vd['steps'].numpy())
for keys, values in zip(self.metrics_dictionary.keys(),self.metrics_dictionary.values()):
new_dictionary[keys] = values.result().numpy()
values.reset_states()
def hide_frame(ax):
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
fig, axs = plt.subplots(2, 4,sharex=True)
fig.tight_layout()
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Step %i / %i" %(self.step_variable.numpy,self.args['max_steps']))
plt.xlabel("Step %i / %i" %(step,self.max_steps))
plt.ylabel("Loss")
if self.args['verbose']:
tf.summary.scalar("NMI", new_dictionary['NMI'], step=self.vd['step_variable'])
self.current_pixel_loss = new_dictionary['reconstruction_loss']
step=self.step_variable.numpy()
side,top = self.__single_device_model_maker__()
with self.train_summary_writer.as_default():
tf.summary.image("top_view",tf.expand_dims(top_view,axis=-1), max_outputs=int(self.args['num_parts']), step=self.step_variable.numpy())
tf.summary.image("side_view",tf.expand_dims(side_view,axis=-1),max_outputs=int(self.args['num_parts']), step=self.step_variable.numpy())
if 'discriminator_loss' in new_dictionary.keys():
tf.summary.scalar('discriminator_loss',new_dictionary['discriminator_loss'], step=step)
self.discriminator_loss.append(new_dictionary['discriminator_loss'])
np.save(join(self.args['results'],'discriminator_loss'),np.asarray(self.discriminator_loss))
axs[0,3].plot(x,y)
axs[0,3].title.set_text('Discriminator error')
axs[0,3].set_xticks([])
hide_frame(axs[0,3])
if 'gen_loss' in new_dictionary.keys():
np.save(join(self.args['results'],'gen_loss'),np.asarray(self.gen_loss))
tf.summary.image("top_view",tf.expand_dims(top_view,axis=-1), max_outputs=top_view.shape[0], step=step)
tf.summary.image("side_view",tf.expand_dims(side_view,axis=-1),max_outputs=side_view.shape[0], step=step)
tf.summary.scalar('gen_loss',new_dictionary['gen_loss'], step=step)
self.gen_loss.append(new_dictionary['gen_loss'])
axs[0,2].plot(x,y)
axs[0,2].title.set_text('Generator error')
axs[0,2].set_xticks([])
hide_frame(axs[0,2])
if 'KL_loss' in new_dictionary.keys():
tf.summary.scalar('KL_loss',new_dictionary['KL_loss'], step=step)
self.kl_loss.append(new_dictionary['KL_loss'])
np.save(join(self.args['results'],'KL_loss'),np.asarray(self.KL_loss))
axs[1,0].plot(x,y )
axs[1,0].title.set_text('KL error')
self.discriminator_loss_list.append(new_dictionary['gen_loss'])
np.save(join(self.args['results'],'gen_loss'),np.asarray(self.discriminator_loss_list))
disc = np.asarray( self.discriminator_loss_list)
axs[0,0].plot(np.arange(step+1),disc)
axs[0,0].title.set_text('Discriminator error')
axs[0,0].set_xticks([])
self.hide_frame(axs[0,0])
np.save(join(self.args['results'],'dis_loss'),np.asarray(self.gen_loss))
tf.summary.scalar('dis_loss',new_dictionary['dis_loss'], step=step)
self.gen_loss_list.append(new_dictionary['dis_loss'])
gen = np.asarray(self.gen_loss_list)
axs[0,1].plot(np.arange(step+1),gen)
axs[0,1].title.set_text('Discriminator error')
axs[0,1].set_xticks([])
self.hide_frame(axs[0,1])
tf.summary.scalar('enc_loss',new_dictionary['enc_loss'], step=step)
self.enc_loss_list.append(new_dictionary['enc_loss'])
np.save(join(self.args['results'],'enc_loss'),np.asarray(self.enc_loss_list))
enc = np.asarray(self.enc_loss_list)
axs[1,0].plot(np.arange(step+1),enc)
axs[1,0].title.set_text('Encoder Loss')
axs[1,0].set_xticks([])
hide_frame(axs[1,0])
if 'trans_loss' in new_dictionary.keys():
tf.summary.scalar('trans_loss',new_dictionary['trans_loss'], step=step)
self.trans_loss.append(new_dictionary['trans_loss'])
np.save(join(self.args['results'],'trans_loss'),np.asarray(self.trans_loss))
axs[1,1].plot(x,y)
axs[1,1].title.set_text('Translational error')
axs[1,1].set_xticks([])
hide_frame(axs[1,1])
if 'feature_loss' in new_dictionary.keys():
tf.summary.scalar('feature_loss',new_dictionary['feature_loss'], step=step)
self.feature_loss.append(new_dictionary['feature_loss'])
np.save(join(self.args['results'],'feature_loss'),np.asarray(self.feature_loss))
axs[1,2].plot(x,y )
axs[1,2].title.set_text('Feature error')
axs[1,2].set_xticks([])
hide_frame(axs[1,2])
if 'noise_loss' in new_dictionary.keys():
tf.summary.scalar('noise_loss',new_dictionary['noise_loss'], step=step)
self.noise_loss.append(new_dictionary['noise_loss'])
np.save(join(self.args['results'],'noise_loss'),np.asarray(self.noise_loss))
axs[1,3].plot(x,y )
axs[1,3].title.set_text('Noise error')
axs[1,3].set_xticks([])
hide_frame(axs[1,3])
self.hide_frame(axs[1,0])
np.save(join(self.args['results'],'umap.npy'),umap)
axs[1,1].plot(umap[:,0],umap[:,1])
axs[1,1].title.set_text('UMAP error')
self.hide_frame(axs[1,1])
self.train_summary_writer.flush()
fig.savefig(join(self.args['results'],'losses_%i.png') %self.step_variable.numpy())
fig.savefig(join(self.args['results'],'losses_%i.png') %step)
fig.clf()
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
fig.suptitle('angels and translations step %i' %self.step_variable.numpy(), share_y = True)
axs[0].hist2d(np.asarray(self.angels['angels_2']).flatten(),np.asarray(self.angels['angels_1']).flatten(),bins=(300, 300), cmap=plt.cm.jet)
axs[0].colorbar()
axs[1].hist2d(np.asarray(self.angels['angels_2']).flatten(),np.asarray(self.angels['angels_3']).flatten(),bins=(300, 300), cmap=plt.cm.jet)
axs[1].colorbar()
axs[2].hist2d(np.asarray(self.angels['angels_1']).flatten(),np.asarray(self.angels['angels_3']).flatten(),bins=(300, 300), cmap=plt.cm.jet)
axs[2].colorbar()
axs[0].title.set_text('angle 2-1')
axs[1].title.set_text('angle 2-3')
axs[2].title.set_text('angle 1-3')
plt.hist2d(np.asarray(self.angels['angels_1']).flatten(),np.asarray(self.angels['angels_2']).flatten(),bins=(300, 300), cmap=plt.cm.jet)
plt.colorbar()
fig.savefig(join(self.args['results'],'histograms_angels_%i.png') %step)
fig.clf()
np.save(join(self.args['results'],'angle_1_step_%s' %step),np.asarray(self.angels['angels_1']).flatten())
np.save(join(self.args['results'],'angle_2_step_%s'%step),np.asarray(self.angels['angels_2']).flatten())
np.save(join(self.args['results'],'angle_3_step_%s'%step),np.asarray(self.angels['angels_3']).flatten())
if not self.args['no_translations']:
fig, axs = plt.subplots(1, 1, figsize=(9, 3))