Commit dfde4461 authored by Jonathan Juhl
Browse files

added different functionalities

parent 9f7bfaf7
This diff is collapsed.
...@@ -18,6 +18,17 @@ class control_flow: ...@@ -18,6 +18,17 @@ class control_flow:
def __init__(self,args): def __init__(self,args):
self.args = args self.args = args
required = [ join(self.args['tmp'],'depth.npy'),
join(self.args['tmp'],'names.npy')]
ctf = [join(self.args['tmp'],'electron_volts.npy'),
join(self.args['tmp'],'spherical_abberation.npy'),
join(self.args['tmp'],'amplitude_contrast.npy'),
join(self.args['tmp'],'ctf_params.npy')]
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
if not isinstance(self.args['star'], list) and isdir(star): if not isinstance(self.args['star'], list) and isdir(star):
...@@ -32,9 +43,12 @@ class control_flow: ...@@ -32,9 +43,12 @@ class control_flow:
self.args['star'] = [join(dir_path,i) for i in self.args['star']] self.args['star'] = [join(dir_path,i) for i in self.args['star']]
if not isfile(join(self.args['tmp'],'depth.npy')): if not all(isfile(i) for i in required):
depth,mrc_paths,voltage = self.get_star_file_parameters() depth,mrc_paths = self.get_star_file_parameters()
elif self.args['ctf'] and not all(isfile(i) for i in ctf):
depth,mrc_paths = self.get_star_file_parameters()
else: else:
depth = np.load(join(self.args['tmp'],'depth.npy')) depth = np.load(join(self.args['tmp'],'depth.npy'))
mrc_paths = np.load(join(self.args['tmp'],'names.npy')) mrc_paths = np.load(join(self.args['tmp'],'names.npy'))
...@@ -46,6 +60,7 @@ class control_flow: ...@@ -46,6 +60,7 @@ class control_flow:
self.args['size'] = length self.args['size'] = length
self.args['number_particles'] = depth self.args['number_particles'] = depth
self.args['bpr'] = bytes_pr_record self.args['bpr'] = bytes_pr_record
if args['ctf']: if args['ctf']:
self.args['kvolts'] = tf.constant(np.load(join(self.args['tmp'],'electron_volts.npy')),tf.float32) self.args['kvolts'] = tf.constant(np.load(join(self.args['tmp'],'electron_volts.npy')),tf.float32)
self.args['sphe_ab'] = tf.constant(np.load(join(self.args['tmp'],'spherical_abberation.npy')),tf.float32) self.args['sphe_ab'] = tf.constant(np.load(join(self.args['tmp'],'spherical_abberation.npy')),tf.float32)
...@@ -57,7 +72,7 @@ class control_flow: ...@@ -57,7 +72,7 @@ class control_flow:
# print(self.args );exit() # print(self.args );exit()
def get_star_file_parameters(self): def get_star_file_parameters(self):
counter = [] counter = []
f = 0 f = 0
names = [] names = []
...@@ -66,7 +81,7 @@ class control_flow: ...@@ -66,7 +81,7 @@ class control_flow:
labels_list = [] labels_list = []
angular_list = [] angular_list = []
count = 0 count = 0
coordinates = []
for z in star_files: for z in star_files:
c = z[::-1].split('/',1)[1][::-1] c = z[::-1].split('/',1)[1][::-1]
with open(z, newline='') as csvfile: with open(z, newline='') as csvfile:
...@@ -104,6 +119,7 @@ class control_flow: ...@@ -104,6 +119,7 @@ class control_flow:
header = take_this header = take_this
name = header.index('_rlnImageName') name = header.index('_rlnImageName')
if self.args['ctf']: if self.args['ctf']:
voltage = take_that.index('_rlnVoltage') voltage = take_that.index('_rlnVoltage')
defocusU = header.index('_rlnDefocusU') defocusU = header.index('_rlnDefocusU')
...@@ -121,30 +137,19 @@ class control_flow: ...@@ -121,30 +137,19 @@ class control_flow:
except: except:
self.args['verbose'] = False self.args['verbose'] = False
print("the --verbose true cannot be run as _rlnClassNumber is missing ") print("the --verbose true cannot be run as _rlnClassNumber is missing ")
if self.args['ang_error']:
try:
class_num = header.index('_rlnClassNumber')
angel_rot = header.index('_rlnAngleRot')
angel_tilt = header.index('_rlnAngleTilt')
angle_psi = header.index('_rlnAnglePsi')
origin_x = header.index('_rlnOriginX')
origin_y = header.index('_rlnOriginY')
except:
self.args['ang_error'] = False
print("the --verbose true cannot be run as angular reconstruction is missing ")
for row in reader: for row in reader:
if self.args['ctf']: if self.args['ctf']:
if len(take_that) == len(row.split()) and counter ==2: if len(take_that) == len(s) and counter ==2:
abberation_d = float(row.split()[abberation]) abberation_d = float(s[abberation])
amp_contrast_d = float(row.split()[amp_contrast]) amp_contrast_d = float(s[amp_contrast])
V = float(row.split()[voltage]) V = float(s[voltage])
electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1)) electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1))
counter = 0 counter = 0
s = row.split()
if len(header)== len(row.split()): if len(header)== len(row.split()):
f+=1 f+=1
...@@ -156,21 +161,19 @@ class control_flow: ...@@ -156,21 +161,19 @@ class control_flow:
else: else:
names.append(join(c,current_name)) names.append(join(c,current_name))
if self.args['verbose']: if self.args['verbose']:
labels_list.append(int(row.split()[class_num])) labels_list.append(int(s[class_num]))
if self.args['ang_error']:
angular_list.append([float(row.split()[angel_rot]),float(row.split()[angel_tilt]),float(row.split()[angle_psi]),float(row.split()[origin_x]),float(row.split()[origin_y])])
if counter == 1: if counter == 1:
if self.args['ctf']: if self.args['ctf']:
V = float(row.split()[voltage]) V = float(s[voltage])
abberation_d = float(row.split()[abberation]) abberation_d = float(s[abberation])
amp_contrast_d = float(row.split()[amp_contrast]) amp_contrast_d = float(s[amp_contrast])
counter = 0 counter = 0
if self.args['ctf']: if self.args['ctf']:
ctf_params.append([float(row.split()[phase_shift]),float(row.split()[defocusU]),float(row.split()[defocusV]),float(row.split()[defocusAngle])]) ctf_params.append([float(s[phase_shift]),float(s[defocusU]),float(s[defocusV]),float(s[defocusAngle])])
current_id = row.split()[name].split('@')[0] current_id = row.split()[name].split('@')[0]
np.save(join(self.args['tmp'],'depth.npy'),f) np.save(join(self.args['tmp'],'depth.npy'),f)
...@@ -182,9 +185,7 @@ class control_flow: ...@@ -182,9 +185,7 @@ class control_flow:
np.save(join(self.args['tmp'],'ctf_params.npy'),np.asarray(ctf_params)) np.save(join(self.args['tmp'],'ctf_params.npy'),np.asarray(ctf_params))
if self.args['verbose']: if self.args['verbose']:
np.save(join(self.args['tmp'],'labels.npy'),np.asarray(labels_list)) np.save(join(self.args['tmp'],'labels.npy'),np.asarray(labels_list))
if self.args['ang_error']: return f,np.unique(names)
np.save(join(self.args['tmp'],'angular_error.npy'),np.asarray(angular_list))
return f,np.unique(names),V
......
...@@ -9,15 +9,18 @@ class GAN_NERF(): ...@@ -9,15 +9,18 @@ class GAN_NERF():
self.args = args self.args = args
self.predict_steps = int(np.ceil(args['number_particles']/(args['num_gpus']*args['p_batch_size']))) self.predict_steps = int(np.ceil(args['number_particles']/(args['num_gpus']*args['p_batch_size'])))
steps = 0 steps = 0
if args['s1'] <= args['size']: stages_list = []
steps+=args['s1'] l = [32,64,128,256]
if args['s2'] <= args['size']: c = [args['s_1'],args['s_2'],args['s_3'],args['s_4']]
steps+=args['s2'] for en,i in enumerate(l):
if args['s3'] <= args['size']: if i <= args['size']:
steps+=args['s3'] steps+=c[en]
if args['s3'] <= args['size']: stages_list.append(en*np.ones(c[en]))
steps+=args['s3'] else:
steps += args['top_off'] stages_list.append(np.concatenate([en*np.ones(c[en]),4*np.ones(self.args['top_off'])],axis=0))
break
self.stages = np.concatenate(stages_list)
self.max_steps = steps self.max_steps = steps
gpu_list = [] gpu_list = []
...@@ -58,41 +61,51 @@ class GAN_NERF(): ...@@ -58,41 +61,51 @@ class GAN_NERF():
args['strategy'] = strategy args['strategy'] = strategy
args['max_steps'] = self.max_steps
self.trainer = Trainer(args) self.trainer = Trainer(args)
self.train() #self.train()
self.predict() self.predict()
def train(self): def train(self):
print('Begin training: ', '-' * 60) print('Begin training: ', '-' * 60)
current_step = self.trainer.step_variable current_step = self.trainer.step_variable
for i in range(int(current_step)):
data = next(self.generator)
for i in range(self.max_steps -current_step): # continue where you came from for i in range(self.max_steps -current_step): # continue where you came from
data = next(self.generator) print("data step %i" %i )
data = next(self.generator)
if self.args['num_gpus'] == 1:
self.trainer.single_device_train_step(data,self.stages[i])
else:
self.trainer.distributed_training_step(data,self.stages[i])
if (i % self.args['vi']) == 0:
self.trainer.write_summaries()
if (i % self.args['save_model']) == 0:
self.trainer.save_checkpoint()
self.trainer.save_best_model()
if i % self.args['movie_int']:
if self.args['num_gpus'] == 1: if self.args['num_gpus'] == 1:
self.trainer.single_device_train_step(data) self.single_device_model_maker.model_maker()
"""else: else:
self.trainer.dis_train_step(data) self.multi_device_model_maker.model_maker()
if i % self.args['val_step']:
self.trainer.metric_step()
if i % self.args['model_step']:
self.trainer.model_step()
if i % self.args['save_step']:
self.trainer.save()"""
def predict(self): def predict(self):
self.trainer.load_best_model()
self.trainer.model_step() #self.trainer.model_maker()
clusters = [] clusters = []
for i in range(self.predict_steps): for i in range(self.predict_steps):
if self.def_dict['num_gpus'] == 1: image = next(self.generator_pred )
self.trainer.predict_step(i)
if self.args['num_gpus'] == 1:
current = self.trainer.single_device_prediction_step(i)
else: else:
self.trainer.dis_predict_step(data) current = self.trainer.dis_predict_step(i)
clusters.append(current)
np.save(join(self.args['results'],'class_labels.npy'),np.asarray(clusters).flatten()[:self.args['number_particles']])
\ No newline at end of file
...@@ -30,6 +30,8 @@ def main(): ...@@ -30,6 +30,8 @@ def main():
parser.add_argument('--movie_int', type=int,default=500,help='validation interval where models at full size are printed out.') parser.add_argument('--movie_int', type=int,default=500,help='validation interval where models at full size are printed out.')
parser.add_argument('--save_model', type=int,default=100,help='validation interval where models at full size are printed out.')
parser.add_argument('--verbose',dest='verbose', action='store_true',help='se the performance of the model by including original class labels') parser.add_argument('--verbose',dest='verbose', action='store_true',help='se the performance of the model by including original class labels')
parser.add_argument('--num_parts',type=int,default=4,help='Number of gaussian components to use. (This is the maximum number)') parser.add_argument('--num_parts',type=int,default=4,help='Number of gaussian components to use. (This is the maximum number)')
...@@ -46,26 +48,72 @@ def main(): ...@@ -46,26 +48,72 @@ def main():
parser.add_argument('--noise', dest='noise',action='store_true',help='Use the noise generator for model. Set true or false boolean.') parser.add_argument('--noise', dest='noise',action='store_true',help='Use the noise generator for model. Set true or false boolean.')
parser.add_argument('--s_1_steps',type=int,default=5000,help='how many steps to generate the 32 x 32 model') parser.add_argument('--s_1_steps',type=int,default=10,help='how many steps to generate the 32 x 32 model')
parser.add_argument('--s_2_steps',type=int,default=4000,help='how many steps to generate the 64 x 64 model') parser.add_argument('--s_2_steps',type=int,default=2,help='how many steps to generate the 64 x 64 model')
parser.add_argument('--s_3_steps',type=int,default=3000,help='how many steps to generate the 128 x 128 model') parser.add_argument('--s_3_steps',type=int,default=5,help='how many steps to generate the 128 x 128 model')
parser.add_argument('--s_4_steps',type=int,default=2000,help='how many steps to generate the 256 x 256 model') parser.add_argument('--s_4_steps',type=int,default=6,help='how many steps to generate the 256 x 256 model')
parser.add_argument('--top_off',type=int,default=1000,help='steps to finish training at the speficied resolution') parser.add_argument('--top_off',type=int,default=6,help='steps to finish training at the speficied resolution')
parser.add_argument('--l_reg',type=float,default=0.01,help='the lambda regulization of the diversity score loss') parser.add_argument('--l_reg',type=float,default=0.01,help='the lambda regulization of the diversity score loss')
parser.add_argument('--feature_size',type=float,default=128,help='the input feature size') parser.add_argument('--feature_size',type=int,default=128,help='the input feature size')
parser.add_argument('--over_cluster', dest='over_cluster',action='store_true',default=False,help='Use CTF parameters for model.') parser.add_argument('--over_cluster', dest='over_cluster',action='store_true',default=False,help='Use CTF parameters for model.')
parser.add_argument('--ang_error', dest='ang_error',action='store_true',default=False,help='Use CTF parameters for model.') parser.add_argument('--dstep',type=int,default=5,help='How many frames over each axis the protein is made in the UMAP reduction')
parser.add_argument('--seg_mode', dest='seg_mode',action='store_true',default = False,help='decomposition of the image to its individuel parts. Can be used on datasets like the ribosome')
parser.add_argument('--noise_bg', dest='noise_bg',action='store_true',default = False,help='To use a noise background estimator to mask instead of the noise generator')
parser.add_argument('--no_gen', dest='no_gen',action='store_true',default = False,help='Using a 3D volumetric model instead of the generator')
parser.add_argument('--TD_mode', dest='TD_mode',action='store_true',default = False,help='If you wish to switch to 2D classification instead')
parser.add_argument('--Only_VAE', dest='Only_VAE',action='store_true',default = False,help='To only use VAE to perform the classification')
parser.add_argument('--m_batch_size',type=int,default=25,help='the batch size to make the 3D model')
parser.add_argument('--frames',type=int,default=36,help='number of movie frames')
parser.add_argument('--use_eulers',dest='use_eulers',action='store_true',help='if to use the standard euler rotation matrix instead')
parser.add_argument('--no_angle', dest='no_angle',action='store_true',default = False,help='Do not use any angles to do the classifcation')
args = parser.parse_args()
if isinstance(args.Only_VAE,bool):
print("perform 2D classification is a bool", isinstance(args.Only_VAE,bool))
else:
assert print("The 2D classification is not a bool")
if isinstance(args.TD_mode,bool):
print("perform 2D classification is a bool" , isinstance(args.TD_mode,bool))
else:
assert print("The 2D classification is not a bool")
if isinstance(args.no_gen,bool):
print("background is instance of: bool", isinstance(args.no_gen,bool))
else:
assert print("The no generation is not a bool")
if isinstance(args.noise_bg,bool):
print("background is instance of: bool", isinstance(args.noise_bg,bool))
else:
assert print("The segmentation is not a bool")
if isinstance(args.seg_mode,bool):
print("segmentation is instance of: bool", isinstance(args.seg_mode,bool))
else:
assert print("The segmentation is not a bool")
if isinstance(args.dstep,int) and args.dstep > 0:
print("dstep is instance of: int", isinstance(args.dstep,int), "and is: %i" %args.dstep)
else:
assert print("the dstep is not an integer or is less than 0")
args = parser.parse_args() if isinstance(args.movie_int,int) and args.movie_int > 0:
print("movie int is instance of: int", isinstance(args.movie_int,int), "and is: %i" %args.movie_int)
else:
assert print("the dstep is not an integer or is less than 0")
if isinstance(args.num_gpus,int) and args.num_gpus > 0: if isinstance(args.num_gpus,int) and args.num_gpus > 0:
...@@ -180,6 +228,30 @@ def main(): ...@@ -180,6 +228,30 @@ def main():
else: else:
assert print("the training steps is not a float") assert print("the training steps is not a float")
if isinstance(args.no_angle,bool):
print("the no angle is a bool")
else:
assert print("the no angle is not a bool")
if isinstance(args.frames,int) and args.frames > 0:
print("the number of frames is an int and is larger than zero")
else:
assert print("it is not a integer and is not larger than 0")
if isinstance(args.m_batch_size,int) and args.m_batch_size > 0:
print("the batch size is an int and is larger than zero")
else:
assert print("it is not a integer and is not larger than 0")
if isinstance(args.use_eulers,bool):
print("use euler angles is a bool")
else:
assert print("use euler angles is not a bool")
if not isdir(args.o): if not isdir(args.o):
mkdir(args.o) mkdir(args.o)
if not isdir(join(args.o,'tmp')): if not isdir(join(args.o,'tmp')):
...@@ -188,10 +260,13 @@ def main(): ...@@ -188,10 +260,13 @@ def main():
mkdir(join(args.o,'model')) mkdir(join(args.o,'model'))
if not isdir(join(args.o,'results')): if not isdir(join(args.o,'results')):
mkdir(join(args.o,'results')) mkdir(join(args.o,'results'))
if not isdir(join(args.o,'best_model')):
mkdir(join(args.o,'best_model'))
model_list = [] model_list = []
for i in range(args.num_parts): for i in range(args.num_parts):
if not isdir(join(args.o,'results')): if not isdir(join(join(args.o,'results'),'model_%i' %i)):
model_list.append(mkdir(join(join(args.o,'results'),'model_%i' %i))) mkdir(join(join(args.o,'results'),'model_%i' %i))
model_list.append(join(join(args.o,'results'),'model_%i' %i))
args_dic = {'num_gpus': args.num_gpus, args_dic = {'num_gpus': args.num_gpus,
'num_cpus': args.num_cpus, 'num_cpus': args.num_cpus,
...@@ -213,24 +288,35 @@ def main(): ...@@ -213,24 +288,35 @@ def main():
'noise': args.noise, 'noise': args.noise,
'tmp': join(args.o,'tmp'), 'tmp': join(args.o,'tmp'),
'model': join(args.o,'model'), 'model': join(args.o,'model'),
'best_model': join(args.o,'best_model'),
'results': join(args.o,'results'), 'results': join(args.o,'results'),
's1': args.s_1_steps, 'save_model': args.save_model,
's2': args.s_2_steps, 's_1': args.s_1_steps,
's3': args.s_3_steps, 's_2': args.s_2_steps,
's4': args.s_4_steps, 's_3': args.s_3_steps,
's_4': args.s_4_steps,
'top_off': args.top_off, 'top_off': args.top_off,
'l_reg': args.l_reg, 'l_reg': args.l_reg,
'feature_size': args.feature_size, 'feature_size': args.feature_size,
'over_cluster': args.over_cluster, 'over_cluster': args.over_cluster,
'models': model_list, 'models': model_list,
'ang_error': args.ang_error, 'movie_int': args.movie_int,
'dstep': args.dstep,
'Only_VAE': args.Only_VAE,
'2D_mode':args.TD_mode,
'noise_bg': args.noise_bg,
'seg_mode': args.seg_mode,
'no_gen': args.no_gen,
'm_batch_size': args.m_batch_size,
'frames': args.frames,
'no_angle': args.no_angle,
'use_eulers': args.use_eulers
} }
control_flow(args_dic) control_flow(args_dic)
......
...@@ -6,14 +6,33 @@ from tensorflow.keras.layers import Flatten,LeakyReLU,Conv2DTranspose, Dense,Con ...@@ -6,14 +6,33 @@ from tensorflow.keras.layers import Flatten,LeakyReLU,Conv2DTranspose, Dense,Con
from utils_sortem import CoordConv2D from utils_sortem import CoordConv2D
import tensorflow_addons as tfa import tensorflow_addons as tfa
from utils_sortem import SMM from utils_sortem import SMM
class Generator_Model(Model):
    """Keras model stub that constructs a ``Model_Layer`` for the given size.

    Args:
        size: value forwarded unchanged to ``Model_Layer``.
    """

    def __init__(self,size):
        # A Keras Model subclass must initialise its base class before it can
        # be used; the original omitted this call.
        super(Generator_Model,self).__init__()
        # NOTE(review): the Model_Layer instance is created and immediately
        # discarded, exactly as in the original. Presumably it was meant to be
        # stored on self and used in a call method -- confirm the intent.
        Model_Layer(size)
class SlotAttention_Model(Model):
    """Keras model that wraps a single ``SlotAttention`` layer.

    Args:
        num_parts: forwarded to ``SlotAttention`` as its second positional
            argument.
        feature_size: forwarded to ``SlotAttention`` as its third positional
            argument.
    """

    def __init__(self,num_parts,feature_size):
        # The original called super(Adapter, self).__init__(); since this
        # class is not a subclass of Adapter, that raises TypeError as soon
        # as an instance is constructed. Name this class instead.
        super(SlotAttention_Model,self).__init__()
        # NOTE(review): the leading 10 is passed positionally; its meaning is
        # defined by SlotAttention (presumably an iteration count -- confirm).
        self.slot_attn = SlotAttention(10, num_parts, feature_size, mlp_hidden_size=256)

    def __call__(self,image):
        """Run ``image`` through the slot-attention layer and return the result."""
        x = self.slot_attn(image)
        return x
class Adapter(Model): class Adapter(Model):
def __init__(self,in_channels): def __init__(self,in_channels):
super(Adapter,self).__init__() super(Adapter,self).__init__()
self.coord_conv_0 = CoordConv2D(1,in_channels) self.coord_conv_0 = Conv2D(in_channels,1)
def __call__(self,image): def __call__(self,image):
x = self.coord_conv_0(image) x = self.coord_conv_0(image)
return x return x
...@@ -29,7 +48,7 @@ class Stage(Model): ...@@ -29,7 +48,7 @@ class Stage(Model):
self.avg_pool_0 = AveragePooling2D(2) self.avg_pool_0 = AveragePooling2D(2)
self.avg_pool_identity = AveragePooling2D(2) self.avg_pool_identity = AveragePooling2D(2)
def __call__(self,image): def call(self,image):
x = self.coord_conv_0(image) x = self.coord_conv_0(image)
x = self.leaky_relu(x) x = self.leaky_relu(x)
...@@ -136,21 +155,25 @@ class ray_maker(Model): ...@@ -136,21 +155,25 @@ class ray_maker(Model):
class SMM_Model(Model): class SMM_Model(Model):
def __init__(self): def __init__(self,n_cluster, n_feature):
super(SMM_Model,self).__init__() super(SMM_Model,self).__init__()