Commit dfde4461 authored by Jonathan Juhl

Added different functionalities

parent 9f7bfaf7
@@ -18,6 +18,17 @@ class control_flow:
def __init__(self,args):
self.args = args
required = [ join(self.args['tmp'],'depth.npy'),
join(self.args['tmp'],'names.npy')]
ctf = [join(self.args['tmp'],'electron_volts.npy'),
join(self.args['tmp'],'spherical_abberation.npy'),
join(self.args['tmp'],'amplitude_contrast.npy'),
join(self.args['tmp'],'ctf_params.npy')]
dir_path = os.path.dirname(os.path.realpath(__file__))
if not isinstance(self.args['star'], list) and isdir(self.args['star']):
@@ -32,9 +43,12 @@ class control_flow:
self.args['star'] = [join(dir_path,i) for i in self.args['star']]
if not isfile(join(self.args['tmp'],'depth.npy')):
depth,mrc_paths,voltage = self.get_star_file_parameters()
if not all(isfile(i) for i in required):
depth,mrc_paths = self.get_star_file_parameters()
elif self.args['ctf'] and not all(isfile(i) for i in ctf):
depth,mrc_paths = self.get_star_file_parameters()
else:
depth = np.load(join(self.args['tmp'],'depth.npy'))
mrc_paths = np.load(join(self.args['tmp'],'names.npy'))
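For orientation, the branch above only re-parses the STAR file when the cached arrays are missing; a minimal sketch of the same gating, where the parse_star callable stands in for get_star_file_parameters:

from os.path import isfile, join
import numpy as np

def load_or_parse(tmp, use_ctf, parse_star):
    required = [join(tmp, 'depth.npy'), join(tmp, 'names.npy')]
    ctf = [join(tmp, 'electron_volts.npy'),
           join(tmp, 'spherical_abberation.npy'),
           join(tmp, 'amplitude_contrast.npy'),
           join(tmp, 'ctf_params.npy')]
    if not all(isfile(p) for p in required):
        return parse_star()                       # nothing cached yet
    if use_ctf and not all(isfile(p) for p in ctf):
        return parse_star()                       # CTF arrays not cached
    return np.load(join(tmp, 'depth.npy')), np.load(join(tmp, 'names.npy'))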
@@ -46,6 +60,7 @@ class control_flow:
self.args['size'] = length
self.args['number_particles'] = depth
self.args['bpr'] = bytes_pr_record
if args['ctf']:
self.args['kvolts'] = tf.constant(np.load(join(self.args['tmp'],'electron_volts.npy')),tf.float32)
self.args['sphe_ab'] = tf.constant(np.load(join(self.args['tmp'],'spherical_abberation.npy')),tf.float32)
@@ -57,7 +72,7 @@ class control_flow:
# print(self.args );exit()
def get_star_file_parameters(self):
counter = []
f = 0
names = []
@@ -66,7 +81,7 @@ class control_flow:
labels_list = []
angular_list = []
count = 0
coordinates = []
for z in star_files:
c = z[::-1].split('/',1)[1][::-1]
with open(z, newline='') as csvfile:
@@ -104,6 +119,7 @@ class control_flow:
header = take_this
name = header.index('_rlnImageName')
if self.args['ctf']:
voltage = take_that.index('_rlnVoltage')
defocusU = header.index('_rlnDefocusU')
@@ -121,30 +137,19 @@ class control_flow:
except:
self.args['verbose'] = False
print("the --verbose true cannot be run as _rlnClassNumber is missing ")
if self.args['ang_error']:
try:
class_num = header.index('_rlnClassNumber')
angel_rot = header.index('_rlnAngleRot')
angel_tilt = header.index('_rlnAngleTilt')
angle_psi = header.index('_rlnAnglePsi')
origin_x = header.index('_rlnOriginX')
origin_y = header.index('_rlnOriginY')
except:
self.args['ang_error'] = False
print("the --verbose true cannot be run as angular reconstruction is missing ")
for row in reader:
if self.args['ctf']:
if len(take_that) == len(row.split()) and counter ==2:
if len(take_that) == len(s) and counter ==2:
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
V = float(row.split()[voltage])
abberation_d = float(s[abberation])
amp_contrast_d = float(s[amp_contrast])
V = float(s[voltage])
electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1))
counter = 0
s = row.split()
if len(header)== len(row.split()):
f+=1
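For context, the header.index lookups above rely on the RELION STAR convention: each _rln label in the loop_ header occupies one line, and its position is the column index of that field in every data row. A minimal sketch of that mapping (the helper name is hypothetical):

def star_column_indices(header_labels):
    # header_labels: the '_rln...' lines of a loop_ block, in file order
    return {label.split()[0]: i for i, label in enumerate(header_labels)}

# usage: cols = star_column_indices(header)
#        defocus_u = float(row.split()[cols['_rlnDefocusU']])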
@@ -156,21 +161,19 @@ class control_flow:
else:
names.append(join(c,current_name))
if self.args['verbose']:
labels_list.append(int(row.split()[class_num]))
labels_list.append(int(s[class_num]))
if self.args['ang_error']:
angular_list.append([float(row.split()[angel_rot]),float(row.split()[angel_tilt]),float(row.split()[angle_psi]),float(row.split()[origin_x]),float(row.split()[origin_y])])
if counter == 1:
if self.args['ctf']:
V = float(row.split()[voltage])
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
V = float(s[voltage])
abberation_d = float(s[abberation])
amp_contrast_d = float(s[amp_contrast])
counter = 0
if self.args['ctf']:
ctf_params.append([float(row.split()[phase_shift]),float(row.split()[defocusU]),float(row.split()[defocusV]),float(row.split()[defocusAngle])])
ctf_params.append([float(s[phase_shift]),float(s[defocusU]),float(s[defocusV]),float(s[defocusAngle])])
current_id = row.split()[name].split('@')[0]
np.save(join(self.args['tmp'],'depth.npy'),f)
@@ -182,9 +185,7 @@ class control_flow:
np.save(join(self.args['tmp'],'ctf_params.npy'),np.asarray(ctf_params))
if self.args['verbose']:
np.save(join(self.args['tmp'],'labels.npy'),np.asarray(labels_list))
if self.args['ang_error']:
np.save(join(self.args['tmp'],'angular_error.npy'),np.asarray(angular_list))
return f,np.unique(names),V
return f,np.unique(names)
@@ -9,15 +9,18 @@ class GAN_NERF():
self.args = args
self.predict_steps = int(np.ceil(args['number_particles']/(args['num_gpus']*args['p_batch_size'])))
steps = 0
if args['s1'] <= args['size']:
steps+=args['s1']
if args['s2'] <= args['size']:
steps+=args['s2']
if args['s3'] <= args['size']:
steps+=args['s3']
if args['s3'] <= args['size']:
steps+=args['s3']
steps += args['top_off']
stages_list = []
l = [32,64,128,256]
c = [args['s_1'],args['s_2'],args['s_3'],args['s_4']]
for en,i in enumerate(l):
if i <= args['size']:
steps+=c[en]
stages_list.append(en*np.ones(c[en]))
else:
stages_list.append(np.concatenate([en*np.ones(c[en]),4*np.ones(self.args['top_off'])],axis=0))
break
self.stages = np.concatenate(stages_list)
self.max_steps = steps
gpu_list = []
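For reference, the loop above implements the progressive-growing schedule: each resolution in [32, 64, 128, 256] contributes its configured step budget while it does not exceed args['size'], and training finishes with 'top_off' steps at the final stage (stage index 4). A minimal sketch of the same bookkeeping (the function name and argument layout are assumptions):

import numpy as np

def build_stage_schedule(size, step_counts, top_off):
    # step_counts mirrors [s_1, s_2, s_3, s_4] above
    steps = top_off
    stages = []
    for stage, (res, n) in enumerate(zip([32, 64, 128, 256], step_counts)):
        if res <= size:
            steps += n
            stages.append(stage * np.ones(n))
        else:
            stages.append(np.concatenate([stage * np.ones(n),
                                          4 * np.ones(top_off)], axis=0))
            break
    return np.concatenate(stages), steps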
@@ -58,41 +61,51 @@ class GAN_NERF():
args['strategy'] = strategy
args['max_steps'] = self.max_steps
self.trainer = Trainer(args)
self.train()
#self.train()
self.predict()
def train(self):
print('Begin training: ', '-' * 60)
current_step = self.trainer.step_variable
for i in range(int(current_step)):
data = next(self.generator)
for i in range(self.max_steps - current_step): # resume from the last checkpointed step
data = next(self.generator)
print("data step %i" %i )
data = next(self.generator)
if self.args['num_gpus'] == 1:
self.trainer.single_device_train_step(data,self.stages[i])
else:
self.trainer.distributed_training_step(data,self.stages[i])
if (i % self.args['vi']) == 0:
self.trainer.write_summaries()
if (i % self.args['save_model']) == 0:
self.trainer.save_checkpoint()
self.trainer.save_best_model()
if (i % self.args['movie_int']) == 0:
if self.args['num_gpus'] == 1:
self.trainer.single_device_train_step(data)
"""else:
self.trainer.dis_train_step(data)
if i % self.args['val_step']:
self.trainer.metric_step()
if i % self.args['model_step']:
self.trainer.model_step()
if i % self.args['save_step']:
self.trainer.save()"""
self.single_device_model_maker.model_maker()
else:
self.multi_device_model_maker.model_maker()
def predict(self):
self.trainer.model_step()
self.trainer.load_best_model()
#self.trainer.model_maker()
clusters = []
for i in range(self.predict_steps):
if self.def_dict['num_gpus'] == 1:
self.trainer.predict_step(i)
image = next(self.generator_pred )
if self.args['num_gpus'] == 1:
current = self.trainer.single_device_prediction_step(i)
else:
self.trainer.dis_predict_step(data)
current = self.trainer.dis_predict_step(i)
clusters.append(current)
np.save(join(self.args['results'],'class_labels.npy'),np.asarray(clusters).flatten()[:self.args['number_particles']])
\ No newline at end of file
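For reference, train() above resumes from a checkpoint by first draining current_step batches from the data generator, so the data stream stays aligned with the restored step counter, and then training for the remaining max_steps - current_step iterations. A minimal sketch of that pattern, where train_step is a stand-in for the single- or multi-GPU step:

def resume_training(generator, train_step, current_step, max_steps):
    for _ in range(int(current_step)):   # fast-forward past already-consumed batches
        next(generator)
    for i in range(int(max_steps - current_step)):
        data = next(generator)
        train_step(data, i)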
@@ -30,6 +30,8 @@ def main():
parser.add_argument('--movie_int', type=int,default=500,help='interval (in steps) at which models at full size are written out.')
parser.add_argument('--save_model', type=int,default=100,help='interval (in steps) at which a model checkpoint is saved.')
parser.add_argument('--verbose',dest='verbose', action='store_true',help='see the performance of the model by including the original class labels')
parser.add_argument('--num_parts',type=int,default=4,help='Number of Gaussian components to use (this is the maximum number).')
@@ -46,26 +48,72 @@ def main():
parser.add_argument('--noise', dest='noise',action='store_true',help='Use the noise generator for the model.')
parser.add_argument('--s_1_steps',type=int,default=5000,help='how many steps to generate the 32 x 32 model')
parser.add_argument('--s_1_steps',type=int,default=10,help='how many steps to generate the 32 x 32 model')
parser.add_argument('--s_2_steps',type=int,default=4000,help='how many steps to generate the 64 x 64 model')
parser.add_argument('--s_2_steps',type=int,default=2,help='how many steps to generate the 64 x 64 model')
parser.add_argument('--s_3_steps',type=int,default=3000,help='how many steps to generate the 128 x 128 model')
parser.add_argument('--s_3_steps',type=int,default=5,help='how many steps to generate the 128 x 128 model')
parser.add_argument('--s_4_steps',type=int,default=2000,help='how many steps to generate the 256 x 256 model')
parser.add_argument('--s_4_steps',type=int,default=6,help='how many steps to generate the 256 x 256 model')
parser.add_argument('--top_off',type=int,default=1000,help='steps to finish training at the specified resolution')
parser.add_argument('--top_off',type=int,default=6,help='steps to finish training at the specified resolution')
parser.add_argument('--l_reg',type=float,default=0.01,help='the lambda regularization of the diversity score loss')
parser.add_argument('--feature_size',type=float,default=128,help='the input feature size')
parser.add_argument('--feature_size',type=int,default=128,help='the input feature size')
parser.add_argument('--over_cluster', dest='over_cluster',action='store_true',default=False,help='Enable over-clustering of the particles.')
parser.add_argument('--ang_error', dest='ang_error',action='store_true',default=False,help='Compute the angular error against the angles given in the star file.')
parser.add_argument('--dstep',type=int,default=5,help='Number of frames per axis used to render the protein in the UMAP reduction')
parser.add_argument('--seg_mode', dest='seg_mode',action='store_true',default = False,help='Decompose the image into its individual parts. Can be used on datasets like the ribosome')
parser.add_argument('--noise_bg', dest='noise_bg',action='store_true',default = False,help='Use a noise background estimator for masking instead of the noise generator')
parser.add_argument('--no_gen', dest='no_gen',action='store_true',default = False,help='Use a 3D volumetric model instead of the generator')
parser.add_argument('--TD_mode', dest='TD_mode',action='store_true',default = False,help='Switch to 2D classification instead of 3D')
parser.add_argument('--Only_VAE', dest='Only_VAE',action='store_true',default = False,help='Use only the VAE to perform the classification')
parser.add_argument('--m_batch_size',type=int,default=25,help='the batch size used to make the 3D model')
parser.add_argument('--frames',type=int,default=36,help='number of movie frames')
parser.add_argument('--use_eulers',dest='use_eulers',action='store_true',help='Use the standard Euler rotation matrix instead')
parser.add_argument('--no_angle', dest='no_angle',action='store_true',default = False,help='Do not use any angles for the classification')
args = parser.parse_args()
if isinstance(args.Only_VAE,bool):
print("Only_VAE is a bool", isinstance(args.Only_VAE,bool))
else:
assert print("Only_VAE is not a bool")
if isinstance(args.TD_mode,bool):
print("the 2D classification flag is a bool" , isinstance(args.TD_mode,bool))
else:
assert print("The 2D classification flag is not a bool")
if isinstance(args.no_gen,bool):
print("no_gen is instance of: bool", isinstance(args.no_gen,bool))
else:
assert print("The no generator flag is not a bool")
if isinstance(args.noise_bg,bool):
print("the noise background flag is instance of: bool", isinstance(args.noise_bg,bool))
else:
assert print("The noise background flag is not a bool")
if isinstance(args.seg_mode,bool):
print("segmentation is instance of: bool", isinstance(args.seg_mode,bool))
else:
assert print("The segmentation flag is not a bool")
if isinstance(args.dstep,int) and args.dstep > 0:
print("dstep is instance of: int", isinstance(args.dstep,int), "and is: %i" %args.dstep)
else:
assert print("the dstep is not an integer or is not greater than 0")
args = parser.parse_args()
if isinstance(args.movie_int,int) and args.movie_int > 0:
print("movie int is instance of: int", isinstance(args.movie_int,int), "and is: %i" %args.movie_int)
else:
assert print("the dstep is not an integer or is less than 0")
if isinstance(args.num_gpus,int) and args.num_gpus > 0:
@@ -180,6 +228,30 @@ def main():
else:
assert print("the training steps is not a float")
if isinstance(args.no_angle,bool):
print("the no angle is a bool")
else:
assert print("the no angle is not a bool")
if isinstance(args.frames,int) and args.frames > 0:
print("the number of frames is an int and is larger than zero")
else:
assert print("it is not a integer and is not larger than 0")
if isinstance(args.m_batch_size,int) and args.m_batch_size > 0:
print("the batch size is an int and is larger than zero")
else:
assert print("it is not a integer and is not larger than 0")
if isinstance(args.use_eulers,bool):
print("use euler angles is a bool")
else:
assert print("use euler angles is not a bool")
if not isdir(args.o):
mkdir(args.o)
if not isdir(join(args.o,'tmp')):
@@ -188,10 +260,13 @@ def main():
mkdir(join(args.o,'model'))
if not isdir(join(args.o,'results')):
mkdir(join(args.o,'results'))
if not isdir(join(args.o,'best_model')):
mkdir(join(args.o,'best_model'))
model_list = []
for i in range(args.num_parts):
if not isdir(join(args.o,'results')):
model_list.append(mkdir(join(join(args.o,'results'),'model_%i' %i)))
if not isdir(join(join(args.o,'results'),'model_%i' %i)):
mkdir(join(join(args.o,'results'),'model_%i' %i))
model_list.append(join(join(args.o,'results'),'model_%i' %i))
args_dic = {'num_gpus': args.num_gpus,
'num_cpus': args.num_cpus,
@@ -213,24 +288,35 @@ def main():
'noise': args.noise,
'tmp': join(args.o,'tmp'),
'model': join(args.o,'model'),
'best_model': join(args.o,'best_model'),
'results': join(args.o,'results'),
's1': args.s_1_steps,
's2': args.s_2_steps,
's3': args.s_3_steps,
's4': args.s_4_steps,
'save_model': args.save_model,
's_1': args.s_1_steps,
's_2': args.s_2_steps,
's_3': args.s_3_steps,
's_4': args.s_4_steps,
'top_off': args.top_off,
'l_reg': args.l_reg,
'feature_size': args.feature_size,
'over_cluster': args.over_cluster,
'models': model_list,
'ang_error': args.ang_error,
'movie_int': args.movie_int,
'dstep': args.dstep,
'Only_VAE': args.Only_VAE,
'2D_mode':args.TD_mode,
'noise_bg': args.noise_bg,
'seg_mode': args.seg_mode,
'no_gen': args.no_gen,
'm_batch_size': args.m_batch_size,
'frames': args.frames,
'no_angle': args.no_angle,
'use_eulers': args.use_eulers
}
control_flow(args_dic)
@@ -6,14 +6,33 @@ from tensorflow.keras.layers import Flatten,LeakyReLU,Conv2DTranspose, Dense,Con
from utils_sortem import CoordConv2D
import tensorflow_addons as tfa
from utils_sortem import SMM
class Generator_Model(Model):
def __init__(self,size):
Model_Layer(size)
class SlotAttention_Model(Model):
def __init__(self,num_parts,feature_size):
super(SlotAttention_Model,self).__init__()
self.slot_attn = SlotAttention(10, num_parts, feature_size, mlp_hidden_size=256)
def __call__(self,image):
x = self.slot_attn(image)
return x
class Adapter(Model):
def __init__(self,in_channels):
super(Adapter,self).__init__()
self.coord_conv_0 = CoordConv2D(1,in_channels)
self.coord_conv_0 = Conv2D(in_channels,1)
def __call__(self,image):
x = self.coord_conv_0(image)
return x
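As a usage sketch of the Adapter above (batch and image sizes are assumptions), the 1x1 convolution simply projects the input image to in_channels feature maps without changing its spatial size:

# adapter = Adapter(in_channels=32)
# out = adapter(tf.zeros([4, 128, 128, 1]))   # -> shape (4, 128, 128, 32)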
@@ -29,7 +48,7 @@ class Stage(Model):
self.avg_pool_0 = AveragePooling2D(2)
self.avg_pool_identity = AveragePooling2D(2)
def __call__(self,image):
def call(self,image):
x = self.coord_conv_0(image)
x = self.leaky_relu(x)
@@ -136,21 +155,25 @@ class ray_maker(Model):
class SMM_Model(Model):
def __init__(self):
def __init__(self,n_cluster, n_feature):
super(SMM_Model,self).__init__()
self.smm = SMM()
self.smm = SMM(n_cluster, n_feature)
def call(self,zeros):
z,z_log,z_var = self.smm(zeros)
def call(self,z, z_mu, z_sigma2_log):
z,z_log,z_var = self.smm(z, z_mu, z_sigma2_log)
return z,z_log,z_var
class Encoder(Model):
def __init__(self,feature_size,encoder=True,f16=False):
def __init__(self,feature_size,angel,encoder=True,f16=False):
super(Encoder,self).__init__()
self.rotation = rotation = 6
if angel:
self.rotation = 6
else:
self.rotation = 2
self.translations = translation = 2
self.encoder = encoder
self.prediction = 1
self.angel = angel
self.latent_vec = int(feature_size/2)
self.s_4 = Stage(400)
self.s_5 = Stage(400)
@@ -168,9 +191,10 @@ class Encoder(Model):
variance = x[:,: self.latent_vec ]
mean = x[:,self.latent_vec:2*self.latent_vec]
rotations = x[:,2*self.latent_vec:(2*self.latent_vec+self.rotation)]
if self.angel:
rotations = x[:,2*self.latent_vec:(2*self.latent_vec+self.rotation)]
translations = x[:,-self.translations:]
translations = tf.nn.tanh(x[:,-self.translations:])
return variance,mean,rotations,translations
else:
@@ -178,7 +202,7 @@ class Encoder(Model):
latent_vec = 2*self.latent_vec
feature_vec = x[:,: latent_vec]
rotations = x[:, latent_vec: latent_vec+self.rotation]
translations = x[:,-(self.translations+2):-2]
translations = tf.nn.tanh(x[:,-(self.translations+2):-2])
prediction = x[:,-self.prediction:]
return prediction,feature_vec,rotations,translations
@@ -188,14 +212,14 @@ class Encoder(Model):
class ResBlock(Model):
def __init__(self, channels, stride=1):
super(ResBlock, self).__init__(name='ResBlock')
self.conv1 = Conv2DTranspose(channels, 3, stride, padding='same')
self.conv1 = Conv2DTranspose(channels, 3, 2, padding='same')
self.bn1 = BatchNormalization()
self.conv2 = Conv2D(channels, 3, padding='same')
self.bn2 = BatchNormalization()
self.relu = ReLU()
self.bn3 = BatchNormalization()
self.conv3 = Conv2D(channels, 1, stride)
self.conv3 = Conv2DTranspose(channels, 1, 2)
def call(self, x):
x1 = self.conv1(x)
@@ -206,7 +230,7 @@ class ResBlock(Model):
x = self.conv3(x)
x = self.bn3(x)
x1 = Layers.add([x, x1])
x1 = x+x1
x1 = self.relu(x1)
return x1
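After this change both paths of ResBlock upsample by a factor of two, so the residual addition still lines up; a shape sketch with assumed batch and channel sizes:

# block = ResBlock(64)
# y = block(tf.zeros([2, 32, 32, 64]))
# main path: Conv2DTranspose(64, 3, 2, padding='same') -> (2, 64, 64, 64)
# skip path: Conv2DTranspose(64, 1, 2)                 -> (2, 64, 64, 64)
# y = relu(main + skip)                                -> (2, 64, 64, 64)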
@@ -40,11 +40,10 @@ class mrc_loader:
for i in kwargs.keys():
self.df_keys[i] = kwargs[i]
width_list = []
alpha_list = []
stages_list = []
s_t1 = tf.linspace(0.0,1.0,self.df_keys['s_1'])
s_t2 = tf.linspace(0.0,1.0,self.df_keys['s_2'])
s_t3 = tf.linspace(0.0,1.0,self.df_keys['s_3'])
@@ -52,51 +51,43 @@ class mrc_loader:
s_t5 = tf.linspace(1.0,1.0,self.df_keys['top_off'])
n_g = self.df_keys['num_gpus']
if self.df_keys['size'] >= 32:
alpha_list.append(s_t1)
stages_list.append(tf.zeros([self.df_keys['s_1']],dtype=tf.int32))
width_list.append(tf.ones([self.df_keys['s_1']*self.df_keys['batch_size']],dtype=tf.int32)*32)
width_list.append(tf.ones([self.df_keys['s_1']*n_g],dtype=tf.int32)*32)
else:
alpha_list.append(tf.concat([s_t1,s_t5],axis=0))
stages_list.append(tf.zeros([self.df_keys['s_1']],dtype=tf.int32))
width_list.append(tf.ones([(self.df_keys['s_2']+self.df_keys['top_off'])*self.df_keys['batch_size']],dtype=tf.int32)*self.df_keys['width'])
width_list.append(tf.ones([(self.df_keys['s_2']+self.df_keys['top_off'])*n_g],dtype=tf.int32)*self.df_keys['width'])
if self.df_keys['size'] >= 64:
alpha_list.append(s_t2)
stages_list.append(tf.ones([self.df_keys['s_2']],dtype=tf.int32)*2)
width_list.append(tf.ones([self.df_keys['s_2']*self.df_keys['batch_size']],dtype=tf.int32)*64)
width_list.append(tf.ones([self.df_keys['s_2']*n_g],dtype=tf.int32)*64)
else:
width_list.append(tf.ones([(self.df_keys['s_2']+self.df_keys['top_off'])*self.df_keys['batch_size']],dtype=tf.int32)*64)
stages_list.append(tf.ones([self.df_keys['s_2']],dtype=tf.int32)*2)
width_list.append(tf.ones([(self.df_keys['s_2']+self.df_keys['top_off'])*n_g],dtype=tf.int32)*64)
alpha_list.append(tf.concat([s_t2,s_t5],axis=0))
if self.df_keys['size'] >= 128:
alpha_list.append(s_t3)
width_list.append(tf.ones([self.df_keys['s_3']*self.df_keys['batch_size']],dtype=tf.int32)*128)
stages_list.append(tf.ones([self.df_keys['s_3']],dtype=tf.int32)*3)
width_list.append(tf.ones([self.df_keys['s_3']*n_g],dtype=tf.int32)*128)
else:
alpha_list.append(tf.concat([s_t3,s_t5],axis=0))
width_list.append(tf.ones([self.df_keys['s_3']*self.df_keys['batch_size']],dtype=tf.int32)*self.df_keys['width'])
stages_list.append(tf.ones([(self.df_keys['s_3']+self.df_keys['top_off'])*self.df_keys['batch_size']],dtype=tf.int32)*3)
width_list.append(tf.ones([self.df_keys['s_3']*n_g],dtype=tf.int32)*self.df_keys['width'])
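For reference, the per-stage tf.linspace tensors above ramp the blending factor alpha from 0 to 1 over that stage's steps (the fade-in used in progressive growing), while s_t5 holds alpha at 1.0 for the top_off steps of the final stage. A minimal sketch of the same schedule (the function name is an assumption):

import tensorflow as tf

def alpha_schedule(stage_steps, top_off, last_stage_index):
    # one 0 -> 1 ramp per stage, held at 1.0 during the top_off phase
    ramps = [tf.linspace(0.0, 1.0, n) for n in stage_steps]
    hold = tf.linspace(1.0, 1.0, top_off)
    ramps[last_stage_index] = tf.concat([ramps[last_stage_index], hold], axis=0)
    return tf.concat(ramps, axis=0)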