Commit 9351160d authored by Jonathan Juhl's avatar Jonathan Juhl

Corrected CTF, memory problems, and the MRC loader

parent 8b361ec5
......@@ -55,7 +55,7 @@ def main():
parser.add_argument('--mb',type=int,default=1,help='pretrain the signal model')
parser.add_argument('--ss',type=int,default=20,help='pretrain the signal model')
parser.add_argument('--ss',type=int,default=10,help='pretrain the signal model')
parser.add_argument('--rs',type=int,default=32,help='pretrain the signal model')
......
......@@ -3,11 +3,43 @@
from super_clas_sortem import MaxBlurPooling2D,super_class
from tensorflow.keras import Model
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten, Lambda, InputSpec, Input,LeakyReLU,Conv2DTranspose, Dense,Conv2D,GlobalAveragePooling2D,BatchNormalization,Activation,UpSampling2D,Conv3D,Conv3DTranspose,LeakyReLU,Dense,UpSampling3D,MaxPool3D,MaxPool2D
from tensorflow.keras.layers import Layer, Flatten, Lambda, InputSpec, Input,LeakyReLU,Conv2DTranspose, Dense,Conv2D,GlobalAveragePooling2D,BatchNormalization,Activation,UpSampling2D,Conv3D,Conv3DTranspose,UpSampling3D,MaxPool3D,MaxPool2D,ReLU
from utils_sortem import Spectral_norm,Instance_norm,transform_3D
class fit_layer(Model):
def __init__(self):
super(fit_layer,self).__init__()
self.watersheed_layer = Watersheed_Layer()
def call(self,mean_matrix,mean_bias,variance_matrix,variance_bias):
loss_value = self.watersheed_layer(mean_matrix,mean_bias,variance_matrix,variance_bias)
return loss_value
class Cluster_Layer(Model):
def __init__(self):
super(Cluster_Layer,self).__init__()
self.dense_mean = Dense(4**3)
self.dense_var = Dense(4**3)
def call(self,categorical_variable):
mean = self.dense_mean(categorical_variable)
s =tf.shape(mean)
batch = s[0]
length = s[1]
epsilon = tf.random.normal(shape=[batch,length])
logvar = self.dense_var(categorical_variable)
out = epsilon*logvar+mean
return out
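For reference, the standard VAE reparameterization scales the noise by the standard deviation, exp(0.5 * logvar), rather than by the log-variance itself; a minimal sketch of that form, assuming dense_var predicts log sigma squared:

def reparameterize(mean, logvar):
    # eps ~ N(0, I); exp(0.5 * logvar) recovers sigma from the predicted log-variance
    eps = tf.random.normal(shape=tf.shape(mean))
    return mean + eps * tf.exp(0.5 * logvar)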
def AdaIn_3D(inputs,s1,b1):
b= tf.shape(inputs)[0]
w= tf.shape(inputs)[1]
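The hunk elides the rest of AdaIn_3D; a minimal sketch of adaptive instance normalization over 5-D feature maps, assuming s1 and b1 are the per-channel style scale and bias emitted by Double_Dense:

def adain_3d(inputs, s1, b1, eps=1e-5):
    # per-sample, per-channel statistics over the spatial axes (D, H, W)
    mean, var = tf.nn.moments(inputs, axes=[1, 2, 3], keepdims=True)
    normalized = (inputs - mean) / tf.sqrt(var + eps)
    # broadcast the [batch, channels] style parameters over the spatial dimensions
    return normalized * s1[:, None, None, None, :] + b1[:, None, None, None, :]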
......@@ -52,7 +84,6 @@ class Double_Dense(Model):
c = self.dense_2(x)
return s,c
class Generator_AdaIN_Noise(Model):
def __init__(self,gf_dim=32):
......@@ -66,29 +97,29 @@ class Generator_AdaIN_Noise(Model):
self.h0 = AdaIn_2D
self.h0_a = LeakyReLU()
self.h1= Conv2DTranspose( self.gf_dim * 32,3,strides=2)
self.h1= Conv2DTranspose( self.gf_dim * 32,3,strides=2,padding='SAME')
self.z_map_1= Double_Dense(self.gf_dim * 32)
self.h1_aI = AdaIn_2D
self.h1_a = LeakyReLU()
self.h2 = Conv2DTranspose(self.gf_dim * 16,3,strides=2)
self.h2 = Conv2DTranspose(self.gf_dim * 16,3,strides=2,padding='SAME')
self.h2_a = LeakyReLU()
self.z_map_2= Double_Dense(self.gf_dim * 16)
self.h2_aI = AdaIn_2D
self.h2_a = LeakyReLU()
self.h5 = Conv2DTranspose(self.gf_dim*4,4,strides=2)
self.h5 = Conv2DTranspose(self.gf_dim*4,4,strides=2,padding='SAME')
self.z_map_4 = Double_Dense(self.gf_dim*4)
self.h6 = AdaIn_2D
self.h6_a = LeakyReLU()
self.h7 = Conv2DTranspose(self.gf_dim*2,4,strides=2)
self.h7 = Conv2DTranspose(self.gf_dim*2,4,strides=2,padding='SAME')
self.z_map_5 = Double_Dense(self.gf_dim*2)
self.h7_in = AdaIn_2D
self.h7_a = LeakyReLU()
self.h8 = Conv2DTranspose(self.gf_dim,4,strides=2)
self.h8 = Conv2DTranspose(self.gf_dim,4,strides=2,padding='SAME')
self.z_map_6 = Double_Dense( self.gf_dim)
self.h8_in = AdaIn_2D
self.h8_a = LeakyReLU()
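The padding='SAME' additions are the substance of this hunk: with 'SAME', a stride-2 Conv2DTranspose exactly doubles the spatial resolution, while the default 'VALID' padding adds kernel_size - stride extra pixels at every layer and drifts off the intended power-of-two schedule. A quick shape check:

import tensorflow as tf
from tensorflow.keras.layers import Conv2DTranspose

x = tf.zeros([1, 8, 8, 16])
print(Conv2DTranspose(32, 3, strides=2, padding='SAME')(x).shape)  # (1, 16, 16, 32)
print(Conv2DTranspose(32, 3, strides=2)(x).shape)                  # (1, 17, 17, 32) under default 'VALID'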
......@@ -138,18 +169,19 @@ class Generator_AdaIN_Noise(Model):
class Generator_AdaIN_res128(Model):
def __init__(self,gf_dim=32):
def __init__(self,gf_dim=64):
super(Generator_AdaIN_res128,self).__init__()
self.trans_3d = transform_3D()
self.cluster = Cluster_Layer()
self.gf_dim = gf_dim
self.h0= Conv3DTranspose( self.gf_dim * 8,3,strides=2)
self.h0= Conv3DTranspose( self.gf_dim * 8,3,strides=2,padding='SAME')
self.zmap_0 = Double_Dense( self.gf_dim * 8)
self.h0_aI = AdaIn_3D
self.h0_a = LeakyReLU()
self.h1 = Conv3DTranspose(self.gf_dim * 4,3,strides=2)
self.h1 = Conv3DTranspose(self.gf_dim * 4,3,strides=2,padding='SAME')
self.z_map_1= Double_Dense(self.gf_dim * 4)
self.h1_aI = AdaIn_3D
self.h1_a = LeakyReLU()
......@@ -165,10 +197,10 @@ class Generator_AdaIN_res128(Model):
# h2_rotated = transform_3D(h2, view_in, 16, 16)
self.h2_proj1 = Conv3D( self.gf_dim*2,3)
self.h2_proj1 = Conv3D( self.gf_dim*2,3,padding='SAME')
self.h2_proj1_a = LeakyReLU()
self.h2_proj2 = Conv3D(self.gf_dim*2,3)
self.h2_proj2 = Conv3D(self.gf_dim*2,3,padding='SAME')
self.h2_proj2_a = LeakyReLU()
# =============================================================================================================
# Collapsing depth dimension
......@@ -179,44 +211,50 @@ class Generator_AdaIN_res128(Model):
self.h4 = Conv2D(self.gf_dim * 16,1)
self.h4_a = LeakyReLU()
self.h5 = Conv2DTranspose(self.gf_dim*8,4,strides=2)
self.h5 = Conv2DTranspose(self.gf_dim*8,4,strides=2,padding='SAME')
self.z_map_4 = Double_Dense(self.gf_dim*8)
self.h6 = AdaIn_2D
self.h6_a = LeakyReLU()
self.h7 = Conv2DTranspose(self.gf_dim*2,4,strides=2)
self.h7 = Conv2DTranspose(self.gf_dim*2,4,strides=2,padding='SAME')
self.z_map_5 = Double_Dense(self.gf_dim*2)
self.h7_in = AdaIn_2D
self.h7_a = LeakyReLU()
self.h8 = Conv2DTranspose(self.gf_dim,4,strides=2)
self.h8 = Conv2DTranspose(self.gf_dim,4,strides=2,padding='SAME')
self.z_map_6 = Double_Dense(self.gf_dim)
self.h8_in = AdaIn_2D
self.h8_a = LeakyReLU()
self.h9 = Conv2DTranspose(1,4,activation='tanh')
self.h9 = Conv2D(1,4,activation='tanh',padding='SAME')
def call(self,z,psi,phi,rho):
def call(self,psi,phi,rho):
z = self.cluster(z)
z = tf.reshape(z,[-1,4,4,4,1])
a,b = self.zmap_0(z)
x = self.h0(z)
x = self.h0_aI(x,a,b)
x = self.h0_a(x)
x = self.h1(x)
a,b = self.z_map_1(z)
x = self.h1_aI(x,a,b)
a,b = self.z_map_2(z)
x = self.h2_aI(x,a,b)
x = self.h2(x)
x = self.trans_3d(x,psi,phi,rho)
x = self.h2_proj1(x)
......@@ -229,92 +267,92 @@ class Generator_AdaIN_res128(Model):
x = self.h4_a(x)
x = self.h5(x)
a,b = self.z_map_4(x)
a,b = self.z_map_4(z)
x = self.h6(x,a,b)
x = self.h6_a(x)
x = self.h7(x)
a,b = self.z_map_5(x)
a,b = self.z_map_5(z)
x = self.h7_in(x,a,b)
x = self.h7_a(x)
x = self.h8(x)
a,b = self.z_map_6(x)
a,b = self.z_map_6(z)
x = self.h8_in(x,a,b)
x = self.h8_a(x)
x = self.h9(x)
return x
return x,z
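A hypothetical driver for the updated generator, written against the older signature that still receives the latent code explicitly (the new call(self,psi,phi,rho) drops z from the argument list yet still feeds z to self.cluster, so some latent input is evidently intended; all shapes below are illustrative):

import numpy as np
import tensorflow as tf

gen = Generator_AdaIN_res128(gf_dim=64)
z = tf.random.uniform([4, 10])        # categorical latent code, batch of 4
psi = tf.random.uniform([4]) * np.pi  # Euler angles consumed by transform_3D
phi = tf.random.uniform([4]) * np.pi
rho = tf.random.uniform([4]) * np.pi
projection, sampled_z = gen(z, psi, phi, rho)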
class Discriminator_AdaIN_res128(Model):
def __init__(self,laten_z_dims,df_dim=32):
def __init__(self,laten_z_dims,df_dim=64):
super(Discriminator_AdaIN_res128,self).__init__()
self.laten_z_dims = laten_z_dims
self.df_dim = df_dim
self.instance_norm_0 = Instance_norm( True)
self.h0 = Spectral_norm(self.df_dim*4,strides=2)
self.h0_a = LeakyReLU()
self.h0_a = ReLU()
self.dh0 =Dense(1)
self.h1 = Spectral_norm(self.df_dim * 8,strides=2)
self.instance_norm_1 = Instance_norm( True)
self.dh1 = Dense(1)
self.h1_a = LeakyReLU()
self.h1_a = ReLU()
self.h2 = Spectral_norm( self.df_dim * 16,strides=2)
self.instance_norm_2 = Instance_norm(True)
self.dh2 = Dense(1)
self.h2_a = LeakyReLU()
self.h2_a = ReLU()
self.h3 = Spectral_norm( self.df_dim * 32,strides=2)
self.instance_norm_3 = Instance_norm( True)
self.h3_a = LeakyReLU()
self.h3_a = ReLU()
self.dh3 = Dense(1)
self.h4_a = LeakyReLU()
self.h4_a = ReLU()
#Returning logits to determine whether the images are real or fake
self.dense_out = Dense(1)
self.act_out = LeakyReLU()
self.act_out = ReLU()
self.encode = Dense(128)
self.predict = Dense(self.laten_z_dims)
self.encode_angels = Dense(128)
self.angles_out = LeakyReLU()
self.angles_out = ReLU()
self.angles =Dense(2,activation='sigmoid')
def style(self,x, h1_mean,h1_var ):
h1_mean = tf.reshape(h1_mean, (batch_size, self.df_dim * 2))
h1_var = tf.reshape(h1_var, (batch_size, self.df_dim * 2))
h1_mean = Flatten()(h1_mean)
h1_var = Flatten()(h1_var)
d_h1_style = tf.concat([h1_mean, h1_var], 0)
return d_h1_style
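The style head reuses the per-channel statistics that instance normalization already computes; a compact sketch of the idea (one style vector per sample via axis=1, whereas the method above stacks the two statistics along the batch axis):

def style_features(x):
    # per-sample, per-channel mean and variance over the spatial axes
    mean, var = tf.nn.moments(x, axes=[1, 2])
    return tf.concat([mean, var], axis=1)  # shape [batch, 2 * channels]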
def call(self,x):
x = self.h0(x)
x, h1_mean,h1_var = self.instance_norm_1(x)
x, h1_mean,h1_var = self.instance_norm_0(x)
d_h1_style = self.style(x, h1_mean,h1_var)
d_logist_0 = self.dh0(d_h1_style)
d_sigmoid_0 = tf.nn.sigmoid(d_logist_0)
x = self.h0_a(x)
x = self.h1(x)
x, h2_mean,h2_var = self.instance_norm_2(x)
x, h2_mean,h2_var = self.instance_norm_1(x)
d_h2_style = self.style( x, h2_mean,h2_var )
d_logist_1 = self.dh1(d_h2_style)
d_sigmoid_1 = tf.nn.sigmoid(d_logist_1)
......@@ -325,23 +363,25 @@ class Discriminator_AdaIN_res128(Model):
x, h3_mean,h3_var = self.instance_norm_2(x)
d_h3_style = self.style( x, h3_mean,h3_var )
d_logist_2= self.dh2(d_h3_style)
d_sigmoid_2 = tf.nn.sigmoid(d_logist_3)
d_sigmoid_2 = tf.nn.sigmoid(d_logist_2)
x = self.h2_a(x)
x = self.h3(x)
x, h4_mean,h4_var = self.instance_norm_3(x)
d_h4_style = self.style( x, h4_mean,h4_var)
d_logist_3 = self.dh3(d_h4_style)
d_h3_style = self.style( x, h4_mean,h4_var)
d_logist_3 = self.dh3(d_h3_style)
d_sigmoid_3= tf.nn.sigmoid(d_logist_3)
x = self.h3_a(x)
x = Flatten()(x)
h5 = self.dense_out(x)
encode = self.encode(x)
latent_out = self.act_out(x)
cont_vars = self.predict(latent_out)
angels = self.encode_angels(x)
angels = self.angles_out(angels)
angels = self.angles(angels)
return tf.nn.sigmoid(h5), h5, tf.nn.tanh(cont_vars), d_logist_0, d_logist_1, d_logist_2, d_logist_3,angels
\ No newline at end of file
return tf.nn.sigmoid(h5), h5, tf.nn.tanh(cont_vars), d_sigmoid_0, d_sigmoid_1, d_sigmoid_2, d_sigmoid_3,angels
\ No newline at end of file
......@@ -10,7 +10,7 @@ class mrc_loader:
def __init__(self,mrc_path,path,bytes_pr_record,mode,prefix,width,precision,num_cpus,verbose,large_rescale=128,small_rescale=128,batch_size=300,abs_mins=0.08,mins=0.20,maxs=1.0,booleans=None,coordinates=None):
def __init__(self,mrc_path,path,bytes_pr_record,prefix,width,precision,num_cpus,verbose,large_rescale=128,batch_size=300):
self.num_cpus = num_cpus
self.mrc_paths = mrc_path
......@@ -18,58 +18,12 @@ class mrc_loader:
self.prefix = prefix
self.path = path
self.image_size = width
self.mins = mins
self.maxs = maxs
self.abs_mins = abs_mins
self.batch_size = batch_size
self.small_rescale = small_rescale
self.precision = tf.float32
self.precision = precision
self.large_rescale= large_rescale
self.bytes_pr_record= bytes_pr_record
self.verbose = verbose
self.boolean = booleans
self.coordinates = coordinates
self.image_feature_description = {'image': tf.io.FixedLenFeature([], tf.string),}
if mode == 'w':
filename =join(path,prefix+'_sinograms.tfrecord')
self.record = tf.io.TFRecordWriter(filename)
if mode == 'r':
l = np.linspace(-np.floor(width/2),np.ceil(width/2),width)
X,Y = np.meshgrid(l,l)
radius = np.sqrt(X**2+Y**2)
mask = radius< width/2
self.mask = tf.constant( mask.astype(float),dtype=tf.float32)
paths = []
for i in listdir(path):
if prefix in i:
paths.append(join(path,i))
self.prefix = prefix
self.paths = paths
def pretrain(self,raw_image,projected_image):
resized_raw = tf.image.resize(raw_image, [self.large_rescale,self.large_rescale])
resized_projected = tf.image.resize(projected_image, [self.large_rescale,self.large_rescale])
perm_image = tf.image.random_flip_left_right(tf.concat([resized_raw,resized_projected],axis=2))
perm_image = tf.image.random_flip_up_down(perm_image)
#perm_image = tf.image.random_crop(perm_image, [self.small_rescale,self.small_rescale,2])
return tf.split(perm_image,2,axis=2)
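Concatenating the raw and projected images along the channel axis before flipping guarantees that both receive the identical random augmentation; the same trick in isolation:

def paired_random_flip(img_a, img_b):
    # one random draw flips both images identically
    pair = tf.concat([img_a, img_b], axis=2)
    pair = tf.image.random_flip_left_right(pair)
    pair = tf.image.random_flip_up_down(pair)
    return tf.split(pair, 2, axis=2)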
def generate(self,mode):
......@@ -79,13 +33,12 @@ class mrc_loader:
image = tf.cast(tf.io.decode_raw(ins,tf.float32),self.precision)
s = tf.cast(tf.sqrt(tf.cast(tf.shape(image)[0],tf.float32)),tf.int32)
image = tf.reshape(image,[s,s,1])
image = tf.image.per_image_standardization(image)
return image
def small_map(a,x):
if len(x) == 2:
x = x[0]
if self.verbose and mode =='predict':
if self.verbose:
a,y = a
a = tf.image.crop_and_resize(tf.expand_dims(a,axis=0),tf.expand_dims(x,axis=0),np.asarray([0]),[self.large_rescale,self.large_rescale])
a = tf.squeeze(a,axis=0)
......@@ -105,111 +58,25 @@ class mrc_loader:
def map_image(ins,labels=None):
if self.image_size < 64:
ins = tf.image.resize(ins,[64,64])
if self.image_size > 64 and self.image_size < 128:
ins = tf.image.resize(ins,[128,128])
if self.image_size > 128 and self.image_size < 256:
ins = tf.image.resize(ins,[256,256])
if mode == 'contrastive':
#images = sub_map(ins,self.mins,self.maxs,self.small_rescale)
return ins
else:
if self.verbose:
return ins,labels
else:
return ins
def tfrecord_reader(raw_string):
return tf.io.parse_single_example(raw_string, self.image_feature_description)['image']
def filters(a,x):
# keep only the records whose boolean flag is set
return x
if mode =='pretrain':
l_dataset = []
for a,b,c in zip(self.mrc_paths[0], self.mrc_paths[1],self.bytes_pr_record):
raw_data = tf.data.FixedLengthRecordDataset(a,c,num_parallel_reads=self.num_cpus, header_bytes=1024).map(premake_image,self.num_cpus).repeat()
projected_data = tf.data.FixedLengthRecordDataset(b,c,num_parallel_reads=self.num_cpus, header_bytes=1024).map(premake_image,self.num_cpus).repeat()
zipdata = tf.data.Dataset.zip((raw_data,projected_data)).map(self.pretrain,self.num_cpus)
l_dataset.append(zipdata)
d = l_dataset[0]
if len(l_dataset) > 1:
for i in l_dataset[1:]:
d = d.concatenate(i)
return d.prefetch(self.batch_size).batch(self.batch_size)
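# Note: Dataset.concatenate returns a new dataset rather than mutating in place,
# hence the reassignment in the loop above; the chain can equally be written as
#   d = functools.reduce(lambda acc, ds: acc.concatenate(ds), l_dataset[1:], l_dataset[0])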
elif mode == 'predict':
if 'tfrecord' in self.mrc_paths:
data = tf.data.TFRecordDataset(self.mrc_paths).map(tfrecord_reader,self.num_cpus).map(premake_image,self.num_cpus)
if self.verbose:
return ins,labels
else:
# print(self.mrc_paths);exit()
data = tf.data.FixedLengthRecordDataset(['/emcc/misser11/EMPIAR_10317/sim_stack.mrcs','/emcc/misser11/EMPIAR_10317/sim_stack_2.mrcs'],self.bytes_pr_record,num_parallel_reads=2, header_bytes=1024).map(premake_image,self.num_cpus)
else:
if 'tfrecord' in self.mrc_paths:
data = tf.data.TFRecordDataset(self.mrc_paths).map(tfrecord_reader,self.num_cpus).map(premake_image,self.num_cpus).repeat()
else:
data = tf.data.FixedLengthRecordDataset(['/emcc/misser11/EMPIAR_10317/sim_stack.mrcs','/emcc/misser11/EMPIAR_10317/sim_stack_2.mrcs'],self.bytes_pr_record,num_parallel_reads=2, header_bytes=1024).repeat().map(premake_image,self.num_cpus)#.shuffle(2000).repeat()
if self.boolean is not None and self.coordinates is not None:
coords = tf.data.Dataset.from_tensor_slices(self.coordinates).repeat()
bools = tf.data.Dataset.from_tensor_slices(self.boolean).repeat()
coords = tf.data.Dataset.zip((coords,bools)).filter(lambda a,x: filters(a,x))
zipdata = tf.data.Dataset.zip((data,coords)).map(lambda a,x: small_map(a,x),self.num_cpus)
data = tf.data.Dataset.zip((zipdata,bools))
elif self.boolean is None and self.coordinates is not None:
coords = tf.data.Dataset.from_tensor_slices(self.coordinates).repeat()
data = tf.data.Dataset.zip((data,coords)).map(lambda a,x: small_map(a,x),self.num_cpus )
return ins
if mode =='predict' and self.verbose:
lab_1 = tf.data.Dataset.from_tensor_slices(np.load(join(self.path,'labels.npy')))
f = tf.data.Dataset.zip((data,lab_1)).map(lambda a,x: map_image(a,x),self.num_cpus).prefetch(self.batch_size).batch(self.batch_size)
else:
f = data.map(map_image,self.num_cpus).prefetch(self.batch_size).batch(self.batch_size)
data = tf.data.FixedLengthRecordDataset(self.mrc_paths,self.bytes_pr_record,num_parallel_reads=2, header_bytes=1024).map(premake_image,self.num_cpus).repeat()
if self.verbose:
lab_1 = tf.data.Dataset.from_tensor_slices(np.load(join(self.path,'labels.npy'))).repeat()
f = tf.data.Dataset.zip((data,lab_1)).map(lambda a,x: map_image(a,x),self.num_cpus).prefetch(self.batch_size).batch(self.batch_size)
else:
f = data
return f
#return data.map(map_image,self.num_cpus).prefetch(self.batch_size).batch(self.batch_size)
# else:
#return f.map(map_image,self.num_cpus).prefetch(self.batch_size).batch(self.batch_size)
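The loader treats each image stack as a fixed-length record file: a 1024-byte MRC header followed by width * width * 4 bytes per float32 image. A standalone sketch of that layout (the path and box size are illustrative, and stacks with extended headers would need a larger header_bytes):

import tensorflow as tf

width = 128                          # illustrative box size
bytes_pr_record = width * width * 4  # one float32 image per record
data = tf.data.FixedLengthRecordDataset('particles.mrcs',  # hypothetical stack
                                        bytes_pr_record,
                                        header_bytes=1024)  # skip the standard MRC header
images = data.map(lambda raw: tf.reshape(tf.io.decode_raw(raw, tf.float32), [width, width, 1]))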
......@@ -204,7 +204,7 @@ class super_class:
lr = (0.000004 *self.batch_size/256)* (1 + tf.cos(gs / total_steps * np.pi)) # cosine-decay learning rate: the weight updates shrink steadily as training progresses.
opt = tf.keras.optimizers.Adam(lr) # the optimizer
opt = tf.keras.optimizers.Adam(lr,beta_1=0.5,beta_2=0.999) # the optimizer
if self.half_precision: #
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt,loss_scale=123) # wrap the optimizer so it can handle float16
......
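The schedule scales a base rate of 4e-6 linearly with batch size and anneals it with a half-cosine that reaches zero at total_steps (the (1 + cos) factor starts at 2, so the peak rate is twice the base); beta_1=0.5 is the usual GAN-stability setting for Adam. The schedule in isolation:

import numpy as np

def cosine_lr(gs, total_steps, batch_size):
    # peaks at 2 * base rate when gs == 0 and decays smoothly to 0 at total_steps
    return (0.000004 * batch_size / 256) * (1 + np.cos(gs / total_steps * np.pi))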
......@@ -5,9 +5,26 @@ from os.path import join
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import math
from tensorflow.keras.layers import Dense,LeakyReLU
#from tensorflow.keras.layers import Dense,conv3D
from tensorflow.keras.layers import Dense,LeakyReLU,Flatten
"""
class Watersheed_Layer(tf.keras.layers.Layer):
def __init__(self):
super(Watersheed_Layer, self).__init__()
def build(self,input_shape):
self.kernel = self.add_weight("offset",
shape=[input_shape[0]],input_shape[1],trainable=True)
def call(self,mean_matrix,mean_bias,variance_matrix,variance_bias):
return (self.kernel -(mean_matrix+mean_bias+variance_matrix+variance_bias))**2
"""
def loss_angels(predict,angels):
return tf.sqrt(tf.reduce_mean((predict-angels)**2))
......@@ -18,22 +35,37 @@ def loss_gen(D_logits_,predict_z,z,predict_angels,angels):
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_logits_), logits=D_logits_))+loss_latent(predict_z,z)+loss_angels(predict_angels,angels)
return g_loss
def loss_disc(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f,D_logits,D_logits_fake,predict_z,z,predict_angels,angels):
z = tf.cast(z,tf.float32)
def loss_encode(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f):
d_h1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_h1_r), logits=d_h1_r)) \
+ tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_h1_f), logits=d_h1_f))
d_h2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_h2_r), logits=d_h2_r)) \
+ tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_h2_f), logits=d_h2_f))
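tf.nn.sigmoid_cross_entropy_with_logits takes labels first and logits second, so the keyword form above avoids silently swapping the two; a minimal sketch of the real/fake pattern these terms follow:

def gan_disc_loss(logits_real, logits_fake):
    # push real samples toward label 1 and generated samples toward label 0
    real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits_real), logits=logits_real))
    fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(logits_fake), logits=logits_fake))
    return real + fake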