Commit b00c6e15 authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

test

parent a54a942b
......@@ -90,11 +90,15 @@ class DynAE(super_class):
@tf.function
def train_g(self,image_0,take_components,angular, ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count,lambdas,spher_abb,ac):
add_images = []
image_0 = tf.math.l2_normalize(image_0,axis=(1,2))
# print(tf.reduce_min(image_0));exit()
catagorial = tf.cast(tf.one_hot(take_components,self.num_parts),self.precision)
angular = tf.cast(tf.one_hot(angular,self.angular_cluster**2),self.precision)
with tf.GradientTape() as t:
z,angels = self.cluster_Layer(catagorial,angular)
#print(z);exit()
fake_projections = self.G(z,angels,inplane_count,t_x_pr_count,t_y_pr_count,t_z_pr_count)
#print(fake_projections);exit()
add_images.append(fake_projections[0])
if self.ctf:
fake_projections = tf.expand_dims(apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac),axis=-1)
......@@ -126,6 +130,8 @@ class DynAE(super_class):
@tf.function
def train_d(self, x_real,take_components,angular,ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count,lambdas,spher_abb,ac):
x_real = tf.math.l2_normalize(x_real,axis=(1,2))
#print(tf.reduce_min(x_real));exit()
add_images = []
catagorial = tf.cast(tf.one_hot(take_components,self.num_parts),self.precision)
angular = tf.cast(tf.one_hot(angular,self.angular_cluster**2),self.precision)
......@@ -133,6 +139,7 @@ class DynAE(super_class):
with tf.GradientTape() as t:
z,angels = self.cluster_Layer(catagorial,angular)
fake_projections = self.G(z,angels,inplane_count,t_x_pr_count,t_y_pr_count,t_z_pr_count)
add_images.append(fake_projections[0])
if self.ctf:
fake_projections = tf.expand_dims(apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac),axis=-1)
......@@ -215,11 +222,12 @@ class DynAE(super_class):
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_it)
image = image.numpy()
#print(np.amin(image));exit()
image = tf.cast(image,self.precision)
loss_d,loss_style,loss_z,images = strategy.run(self.train_d,(image,takes,angular, ctf_pr_count,t_x_pr_count*0,t_y_pr_count*0,t_z_pr_count*0,inplane_count*0,lambdas,spher_abb,ac)) # loss_1,
......@@ -236,8 +244,9 @@ class DynAE(super_class):
else:
image = next(dist_it)
image = image.numpy()
for i in range(4):
for i in range(2):
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
angular = np.random.choice(np.arange(self.angular_cluster**2),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,t_z_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
......
......@@ -56,7 +56,7 @@ class Cluster_Layer(Model):
epsilon_angel = tf.random.normal(shape=[s[0],2])
out_angel = epsilon_angel*logvar_angel+mean_angel
#;exit()
return tf.transpose(out),out_angel
def AdaIn_3D(inputs,s1,b1):
b= tf.shape(inputs)[0]
......
......@@ -100,7 +100,7 @@ class transform_3D(tf.keras.layers.Layer):
self.kernel = tf.cast(self.coordinates,tf.float32)
def rotation_map(self,voxels,alpha,beta,gamma,x_translate,y_translate,z_translate):
alpha = 2*3.1415*tf.cast(alpha,tf.float32)
beta =2*3.1415*tf.cast(beta,tf.float32)
gamma = 2*3.1415*tf.cast(gamma,tf.float32)
......@@ -193,7 +193,7 @@ class transform_3D(tf.keras.layers.Layer):
return out
def call(self,voxels,alpha,beta,gamma,x_translate,y_translate,z_translate):
image = tf.map_fn(lambda x: self.rotation_map(tf.squeeze(x[0]),x[1],x[2],x[3],x[4],x[5],x[6]),[voxels,alpha,beta,gamma,x_translate,y_translate,z_translate],dtype=tf.float32)
return image
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment