Commit 54423fbc authored by Jonathan Juhl

now with save checkpoints

parent 9bb7d614
@@ -9,13 +9,13 @@ from os.path import isdir,join
 def main():
     parser = argparse.ArgumentParser(description='Run sortinator.')
-    parser.add_argument('--gpu_id',type=int, default= 0, help='GPU ID')
-    parser.add_argument('--num_cpus',type=int,default = 8,help='The maximum allowed cpus to use for preprocessing data and Kmeans clustering')
+    parser.add_argument('--gpu_id',type=int, default= 0, help='The ID of the GPU to execute the operations on.')
+    parser.add_argument('--num_cpus',type=int,default = 8,help='The maximum number of CPUs to use for preprocessing data (image resizing and normalization)')
     parser.add_argument('--star', type=str, nargs='+',
                         help='list of paths to the star files; wildcards are accepted. The star files must refer to the .mrc files')
     parser.add_argument('--batch_size', type=int,default=[100,75,50,20,10], nargs='+',
-                        help='deep learning model training batch')
+                        help='deep learning model training batch size for each image scale')
     parser.add_argument('--o', type=str,default='./results',
                         help='output directory')
@@ -23,13 +23,13 @@ def main():
     parser.add_argument('--f16', dest='f16',action='store_true',
                         help='Apply Tensor Core acceleration to training and inference; requires compute capability 7.0 or higher.')
-    parser.add_argument('--save_model', type=int,default=5,help='validation interval where models at full size are printed out.')
+    parser.add_argument('--save_model', type=int,default=500,help='the interval, in training steps, at which full-size models are written out.')
-    parser.add_argument('--lr_g',type=float,default=[10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6),10**(-7),0.5*10**(-7)], nargs='+',help='The start learning rate of the generator')
+    parser.add_argument('--lr_g',type=float,default=[10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6),10**(-7),0.5*10**(-7)], nargs='+',help='The staircase learning rates of the generator')
-    parser.add_argument('--lr_d',type=float,default=[10**(-4),0.5*10**(-4),10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6)], nargs='+',help='The start learning rate of the descriminator')
+    parser.add_argument('--lr_d',type=float,default=[10**(-4),0.5*10**(-4),10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6)], nargs='+',help='The staircase learning rates of the discriminator')
-    parser.add_argument('--lr_e',type=float,default=[10**(-4),0.5*10**(-4),10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6)], nargs='+',help='The start learning rate of the encoder')
+    parser.add_argument('--lr_e',type=float,default=[10**(-4),0.5*10**(-4),10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6)], nargs='+',help='The staircase learning rates of the encoder')
     parser.add_argument('--ctf', dest='ctf',action='store_true',default=False,help='Use CTF parameters for model.')
@@ -39,19 +39,17 @@ def main():
     parser.add_argument('--l_reg',type=float,default=0.01,help='the lambda regularization of the diversity score loss if the noise generator is active')
     parser.add_argument('--m_batch_size',type=int,default=25,help='the batch size used to make the 3D model')
     parser.add_argument('--frames',type=int,default=4,help='number of models to generate from each cluster')
-    parser.add_argument('--umap_p_size',type=int,default=100,help='The UMAP size to train the umap model. It is trained on the CPU in parallel')
+    parser.add_argument('--umap_p_size',type=int,default=10000,help='The number of feature vectors to use for training the UMAP model')
-    parser.add_argument('--umap_t_size',type=int,default=100,help='The UMAP size')
+    parser.add_argument('--umap_t_size',type=int,default=10000,help='The number of feature vectors to use for intermediate evaluation of clusters in the UMAP algorithm')
     parser.add_argument('--neighbours',type=int,default=30,help='number of neighbours in the graph creation algorithm')
-    parser.add_argument('--t_res',type=int,default=None,choices=[32,64,128,256,512],help='number of neighbours in the graph creation algorithm')
+    parser.add_argument('--t_res',type=int,default=None,choices=[32,64,128,256,512],help='The maximum resolution to train the model on')
-    parser.add_argument('--minimum_size',type=int,default=500,help='the minimum size before its considered an actual cluster, anything else less is considered noise and will be discarded')
+    parser.add_argument('--minimum_size',type=int,default=500,help='the minimum size before a group is considered an actual cluster; anything smaller is considered noise')
     args = parser.parse_args()
@@ -86,7 +84,6 @@ def main():
         'save_model': args.save_model,
         'steps': args.steps,
         'l_reg': args.l_reg,
-        'm_batch_size': args.m_batch_size,
         'frames': args.frames,
         'umap_p_size': args.umap_p_size,
         'umap_t_size': args.umap_t_size,
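Note: per the help texts above, the batch sizes and learning rates are per-scale lists, one entry per image resolution of the progressively trained model (smaller batches and lower rates at larger scales). A minimal sketch of how such staircase lists could be consumed; the loop below is illustrative only and not part of this commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=[100,75,50,20,10], nargs='+',
                    help='training batch size for each image scale')
parser.add_argument('--lr_g', type=float,
                    default=[10**(-5),0.5*10**(-5),10**(-6),0.5*10**(-6),10**(-7),0.5*10**(-7)],
                    nargs='+', help='staircase learning rates of the generator')
args = parser.parse_args([])  # empty argv: demonstrate the defaults

# One entry per image scale; clamp the index in case the lists differ in length.
for scale, batch in enumerate(args.batch_size):
    lr_g = args.lr_g[min(scale, len(args.lr_g) - 1)]
    print('scale %d: batch_size=%d, lr_g=%.1e' % (scale, batch, lr_g))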
@@ -19,15 +19,16 @@ class Trainer:
         self.discriminator_loss_list = []
         self.gen_loss_list = []
         self.enc_loss_list = []
-        if isfile(join(args['results'],'discriminator_loss')):
+        if isfile(join(args['results'],'dis_loss.npy')):
             self.discriminator_loss_list += np.load(join(args['results'],'dis_loss.npy')).tolist()
         else:
             self.discriminator_loss_list = []
-        if isfile(join(args['results'],'gen_loss')):
+        if isfile(join(args['results'],'gen_loss.npy')):
             self.gen_loss_list += np.load(join(args['results'],'gen_loss.npy')).tolist()
         else:
             self.gen_loss_list = []
-        if isfile(join(args['results'],'enc_loss')):
+        if isfile(join(args['results'],'enc_loss.npy')):
             self.enc_loss_list += np.load(join(args['results'],'enc_loss.npy')).tolist()
         else:
             self.enc_loss_list = []
@@ -343,7 +344,7 @@ class Trainer:
         axs[0,0].title.set_text('Discriminator error')
         axs[0,0].set_xticks([])
         self.hide_frame(axs[0,0])
-        print(self.discriminator_loss_list)
+        np.save(join(self.args['results'],'dis_loss'),np.asarray(self.discriminator_loss_list))
         tf.summary.scalar('dis_loss',new_dictionary['dis_loss'], step=step)
         gen = np.asarray(self.discriminator_loss_list)
@@ -593,11 +594,9 @@ class Trainer:
         params.update(self.__Generator_Step__(im_2,ctf_2,alpha,Discriminator,Encoder,noise,transform))
         params.update(self.__Discriminator_Step__(im_3,ctf_3,alpha,Discriminator,Encoder,noise,transform))
-        self.discriminator_loss_list = [params['dis_loss']]
-        self.gen_loss_list = [params['gen_loss']]
-        self.enc_loss_list = [params['enc_loss']]
+        self.discriminator_loss_list += [params['dis_loss'].numpy()]
+        self.gen_loss_list += [params['gen_loss'].numpy()]
+        self.enc_loss_list += [params['enc_loss'].numpy()]
         return params
     @tf.function
     def generate(self,im_size,grid_part,means):
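The checkpointing pattern this commit introduces: loss histories are appended as plain floats (hence the .numpy() calls), persisted with np.save at each save interval, and reloaded on startup so a resumed run keeps its loss curves. A self-contained sketch of that round trip; the directory and loss value are hypothetical:

import os
import numpy as np
from os.path import isfile, join

results = './results'  # hypothetical output directory (--o)
os.makedirs(results, exist_ok=True)

# On startup: resume the loss history if a previous run saved one.
# np.save appends the '.npy' extension, which is why the code checks
# for 'dis_loss.npy' while the save target below is just 'dis_loss'.
dis_loss_list = []
if isfile(join(results, 'dis_loss.npy')):
    dis_loss_list += np.load(join(results, 'dis_loss.npy')).tolist()

# During training: append host-side floats rather than tensors so the
# list can be serialized (this mirrors the .numpy() calls in the diff).
dis_loss_list += [0.42]

# At the save interval: persist the whole history as 'dis_loss.npy'.
np.save(join(results, 'dis_loss'), np.asarray(dis_loss_list))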