Commit a54a942b authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

correct loss function

parent 90bb6c10
......@@ -55,21 +55,6 @@ class DynAE(super_class):
return prob_density
@tf.function
def interpolation(self, class_centroids, num_interpolants):
    """Interpolate between the two least-similar class centroids.

    Finds, from the cosine-similarity matrix of the L2-normalized
    centroids, a source centroid and per-centroid targets with minimal
    similarity, walks a straight line between the normalized vectors,
    and projects each point back onto the centroid basis through a
    softmax weighting.

    Args:
        class_centroids: 2-D tensor of centroid vectors, one per row
            (assumed shape (num_classes, dim) -- TODO confirm).
        num_interpolants: requested number of interpolation samples.
            NOTE(review): the original body ignored this argument and
            read ``self.interpolation_num_samples``; that behavior is
            preserved here -- confirm which one is intended.

    Returns:
        Tensor of softmax-weighted interpolated vectors.
    """
    # Fixes vs. original: the method lacked `self` (yet read
    # self.interpolation_num_samples), `tf.l2_norm` and `tf.rduce_min`
    # do not exist, tf.linspace needs float endpoints, and gathering
    # rows by a tensor of indices requires tf.gather.
    normalized = tf.math.l2_normalize(class_centroids, axis=-1)
    # Pairwise cosine similarity between normalized centroids.
    tmp = tf.matmul(normalized, normalized, transpose_b=True)
    # Per-row least-similar partner, and the globally least-similar row.
    t_target = tf.argmin(tmp, axis=1)
    t_source = tf.argmin(tf.reduce_min(tmp, axis=1))
    lin_space = tf.linspace(0.0, 1.0, self.interpolation_num_samples)
    # NOTE(review): broadcasting lin_space (1-D) against the gathered
    # 2-D rows follows the original formula -- verify the intended
    # shapes before relying on this.
    strait_line_vector = (tf.gather(normalized, t_target) * lin_space
                          + (1 - lin_space) * tf.gather(normalized, t_source))
    softweighted_vectors = tf.matmul(
        tf.nn.softmax(tf.matmul(strait_line_vector, normalized, transpose_b=True)),
        class_centroids, transpose_b=True)
    return softweighted_vectors
@tf.function
def predict_cluster(self,num_classes,angular,images):
num_classes = tf.one_hot(num_classes,self.num_parts)
......@@ -132,12 +117,12 @@ class DynAE(super_class):
lz = loss_latent(predict_z_f,catagorial)
loss_tot += lz
lg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(h5_sig_f, tf.ones_like(h5_sig_f)))
lg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(tf.ones_like(h5_sig_f),h5_sig_f))
loss_tot += lg
grad = t.gradient(loss_tot, self.G.trainable_variables+self.cluster_Layer.trainable_variables)
self.g_opt.apply_gradients(zip(grad, self.G.trainable_variables+self.cluster_Layer.trainable_variables))
return lz,lg,add_images
return lg,lg,add_images
@tf.function
def train_d(self, x_real,take_components,angular,ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count,lambdas,spher_abb,ac):
......@@ -223,25 +208,10 @@ class DynAE(super_class):
for i in range(2):
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
angular = np.random.choice(np.arange(self.angular_cluster**2),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_it)
image = image.numpy()
image = tf.cast(image,self.precision)
loss_d,loss_style,loss_z,images = strategy.run(self.train_d,(image,takes,angular, ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count,lambdas,spher_abb,ac)) # loss_1,
l_list.append([loss_d,loss_style,loss_z])
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
angular = np.random.choice(np.arange(self.angular_cluster**2),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
......@@ -250,15 +220,32 @@ class DynAE(super_class):
image = next(dist_it)
image = image.numpy()
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
angular = np.random.choice(np.arange(self.angular_cluster**2),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
loss_g,loss_q,images = strategy.run(self.train_g,(image,takes,angular, ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count,lambdas,spher_abb,ac ))
image = tf.cast(image,self.precision)
loss_d,loss_style,loss_z,images = strategy.run(self.train_d,(image,takes,angular, ctf_pr_count,t_x_pr_count*0,t_y_pr_count*0,t_z_pr_count*0,inplane_count*0,lambdas,spher_abb,ac)) # loss_1,
loss_d = loss_d.numpy()
loss_style = loss_style.numpy()
loss_style = loss_style.numpy()
loss_z = loss_z.numpy()
print('training time: ', time() - t0,"step: %i of %i " %(ite,self.steps),loss_d,loss_style,loss_z)
l_list.append([loss_d,loss_style,loss_z])
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_it)
image = image.numpy()
for i in range(4):
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
angular = np.random.choice(np.arange(self.angular_cluster**2),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
loss_q,loss_g,images = strategy.run(self.train_g,(image,takes,angular, ctf_pr_count*0,t_x_pr_count*0,t_y_pr_count*0,t_z_pr_count*0,inplane_count*0,lambdas,spher_abb,ac ))
loss_g = loss_g.numpy()
print('training time: ', time() - t0,"step: %i of %i " %(ite,self.steps),loss_d,loss_style,loss_z,loss_g)
take = []
#print(ite % self.validate_interval)
......
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
......@@ -47,22 +49,22 @@ def loss_latent(predict,catagories):
def loss_gen(D_logits_, predict_z, z):
    """Generator loss: adversarial sigmoid term plus latent-prediction term.

    Args:
        D_logits_: discriminator logits for generated (fake) samples.
        predict_z: predicted latent code for the generated samples.
        z: target latent code passed through to ``loss_latent``.

    Returns:
        Scalar: mean sigmoid cross-entropy against all-ones labels
        (generator wants fakes classified as real) plus the latent loss.
    """
    # tf.nn.sigmoid_cross_entropy_with_logits takes labels FIRST; the
    # dead duplicate assignment with swapped arguments (a leftover of
    # the "correct loss function" diff) is removed.
    adv_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(D_logits_), logits=D_logits_))
    return adv_loss + loss_latent(predict_z, z)
def loss_encode(d_h1_r, d_h1_f, d_h2_r, d_h2_f, d_h3_r, d_h3_f, d_h4_r, d_h4_f):
    """Sum of discriminator feature-level losses over four layer pairs.

    For each layer, real logits are scored against all-ones labels and
    fake logits against all-zeros labels with sigmoid cross-entropy.

    Args:
        d_h1_r..d_h4_r: discriminator logits for real samples at
            intermediate layers 1-4.
        d_h1_f..d_h4_f: matching logits for fake samples.

    Returns:
        Scalar: sum of the four per-layer real+fake losses.
    """
    # The original assigned each d_h*_loss twice -- once with the
    # (logits, labels) arguments swapped -- a diff-conflation artifact.
    # Only the corrected order (labels first) is kept, factored into a
    # single helper to remove the fourfold repetition.
    def _layer_loss(real_logits, fake_logits):
        # Real samples labeled 1, fake samples labeled 0.
        real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_logits), logits=real_logits))
        fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_logits), logits=fake_logits))
        return real + fake

    return (_layer_loss(d_h1_r, d_h1_f)
            + _layer_loss(d_h2_r, d_h2_f)
            + _layer_loss(d_h3_r, d_h3_f)
            + _layer_loss(d_h4_r, d_h4_f))
......@@ -70,8 +72,8 @@ def loss_encode(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f):
def loss_disc(D_logits,D_logits_fake,predict_z,z):
z = tf.cast(z,tf.float32)
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D_logits)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_fake, tf.zeros_like(D_logits_fake)))
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(tf.ones_like(D_logits),D_logits))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(tf.zeros_like(D_logits_fake),D_logits_fake))
d_loss = d_loss_real+d_loss_fake
......@@ -87,7 +89,9 @@ class transform_3D(tf.keras.layers.Layer):
self.full_image = full_image
def build(self,input_shape):
self.dimensions = input_shape[1]
self.dimensions = input_shape[-2]
self.channels = input_shape[-1]
x = tf.range(-int(np.floor(self.dimensions/2)),int(np.ceil(self.dimensions/2)))
X,Y,Z = tf.meshgrid(x,x,x)
......@@ -104,9 +108,7 @@ class transform_3D(tf.keras.layers.Layer):
y_translate = tf.cast(y_translate,tf.float32)
z_translate = tf.cast(z_translate,tf.float32)
#voxels = tf.transpose(voxels,perm=[1,2,3,0,4])
dimensions = tf.shape(voxels)[0]
channels = tf.shape(voxels)[-1]
#batchdim = tf.shape(voxels)[-2]
rotation_matrix_x =tf.stack([tf.ones_like(alpha),tf.zeros_like(alpha),tf.zeros_like(alpha),
......@@ -127,11 +129,11 @@ class transform_3D(tf.keras.layers.Layer):
rotation_matrix_y = tf.reshape(rotation_matrix_y, (3,3))
rotation_matrix_z = tf.reshape(rotation_matrix_z, (3,3))
s = tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z))
r = tf.matmul(tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z)) ,self.kernel)
x,y,z = tf.split(r,3,axis=0)
X = tf.reshape(x,[-1])+x_translate*(self.dimensions/self.full_image)
Y = tf.reshape(y,[-1])+y_translate*(self.dimensions/self.full_image)
Z = tf.reshape(z,[-1])+z_translate*(self.dimensions/self.full_image)
......@@ -148,19 +150,19 @@ class transform_3D(tf.keras.layers.Layer):
y_d = (Y-Y_lower+0.001)/(Y_upper-Y_lower+0.001)
z_d = (Z-Z_lower+0.001)/(Z_upper-Z_lower+0.001)
coord_000 = tf.stack([X_lower,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_001 = tf.stack([X_lower,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_011 = tf.stack([X_lower,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_111 = tf.stack([X_upper,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_101 = tf.stack([X_upper,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_100 = tf.stack([X_upper,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_010 = tf.stack([X_lower,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_110 = tf.stack([X_upper,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_000 = tf.stack([X_lower,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_001 = tf.stack([X_lower,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_011 = tf.stack([X_lower,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_111 = tf.stack([X_upper,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_101 = tf.stack([X_upper,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_100 = tf.stack([X_upper,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_010 = tf.stack([X_lower,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
coord_110 = tf.stack([X_upper,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(self.dimensions/2),tf.float32)
#voxels = tf.reshape(voxels,[dimensions**3,channels])
c000 = tf.gather_nd(voxels,tf.cast(coord_000,tf.int32))
# print(c000);exit()
c001 = tf.gather_nd(voxels,tf.cast(coord_001,tf.int32))
c011 = tf.gather_nd(voxels,tf.cast(coord_011,tf.int32))
......@@ -175,7 +177,7 @@ class transform_3D(tf.keras.layers.Layer):
z_d = tf.expand_dims(z_d,axis=1)
c00 = c000*(1-x_d) + c100*x_d
c01 = c001*(1-x_d) + c101*x_d
c10 = c010*(1-x_d) + c110*x_d
c11 = c011*(1-x_d) + c111*x_d
......@@ -186,8 +188,8 @@ class transform_3D(tf.keras.layers.Layer):
c = c0*(1-z_d)+c1*z_d
out = tf.reshape(c,[dimensions,dimensions,dimensions,channels])
out = tf.reshape(c,[self.dimensions,self.dimensions,self.dimensions,self.channels])
return out
def call(self,voxels,alpha,beta,gamma,x_translate,y_translate,z_translate):
......@@ -384,9 +386,17 @@ def apply_ctf(image,ctf_params,KVolts,spherical_abberation,w2):
return ctf_image
#import mrcfile
"""
data = np.load('data.npy')
data = np.reshape(data,(1,256,256,256,1))
t = transform_3D(256)
v = np.asarray([0.0])
k = np.asarray([128])
out = t(data,v,v,v,k,k,v)
plt.imshow(np.squeeze(np.sum(out,axis=1)))
plt.savefig('try_it.png')
#import mrcfile"""
"""V = 300
lambdas = 10**(-4)*10**(-6)*12.25*10**(-10)/np.sqrt(V*10**3)
......@@ -431,4 +441,4 @@ with mrcfile.open('/emcc/misser11/EMPIAR_10317/out_noise.mrcs') as mrc:
plt.clf()
t+=1
"""
\ No newline at end of file
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment