Commit dff8d39b authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

new 3d view

parents 434e9c23 95fd1e85
import tensorflow as tf
import numpy as np
from os.path import join
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import math
from tensorflow.keras.layers import Dense,LeakyReLU
#from tensorflow.keras.layers import Dense,conv3D
def loss_angels(predict,angels):
return tf.sqrt(tf.reduce_mean((predict-angels)**2))
def loss_latent(predict,z):
return tf.sqrt(tf.reduce_mean((predict-z )**2))
def loss_gen(D_logits_,predict_z,z,predict_angels,angels):
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_logits_)))+loss_latent(predict_z,z)+loss_angels(predict_angels,angels)
return g_loss
def loss_disc(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_,D_logits,D_logits_fake,predict_z,z,predict_angels,angels):
d_h1_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h1_r, tf.ones_like(d_h1_r))) \
+ tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h1_f, tf.zeros_like(d_h1_f)))
d_h2_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h2_r, tf.ones_like(d_h2_r))) \
+ tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h2_f, tf.zeros_like(d_h2_f)))
d_h3_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h3_r, tf.ones_like(d_h3_r))) \
+ tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h3_f, tf.zeros_like(d_h3_f)))
d_h4_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h4_r, tf.ones_like(d_h4_r))) \
+ tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_h4_f, tf.zeros_like(d_h4_f)))
d_loss_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D_logits)))
d_loss_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(D_logits_fake, tf.zeros_like(D_logits_fake)))
d_loss = d_h1_loss+d_h2_loss+d_h3_loss+d_h4_loss+d_loss_real+d_loss_fake+loss_latent(predict_z,z)+loss_angels(predict_angels,angels)
return d_loss
class transform_3D(tf.keras.layers.Layer):
def __init__(self):
super(transform_3D, self).__init__()
def build(self,input_shape):
self.dimensions = input_shape[1]
x = tf.range(self.dimensions)
X,Y,Z = tf.meshgrid(x,x,x)
self.coordinates = tf.stack([tf.reshape(X,[-1]), tf.reshape(Y,[-1]),tf.reshape(Z,[-1])],axis=0)
self.kernel = tf.cast(self.coordinates,tf.float32)
def rotation_map(self,voxels,alpha,beta,gamma):
alpha = 2*3.1415*tf.cast(alpha,tf.float32)
beta =2*3.1415*tf.cast(beta,tf.float32)
gamma = 2*3.1415*tf.cast(gamma,tf.float32)
voxels = tf.transpose(voxels,perm=[1,2,3,0,4])
dimensions = tf.shape(voxels)[0]
channels = tf.shape(voxels)[-1]
batchdim = tf.shape(voxels)[-2]
rotation_matrix_x =tf.stack([tf.constant(1.0),tf.constant(0.0),tf.constant(0.0),
tf.constant(0.0),tf.cos(alpha), -tf.sin(alpha),
tf.constant(0.0),tf.sin(alpha), tf.cos(alpha)])
rotation_matrix_y = tf.stack([
tf.cos(beta),tf.constant(0.0), tf.sin(beta),
tf.constant(0.0),tf.constant(1.0),tf.constant(0.0),
-tf.sin(beta),0, tf.cos(beta)])
rotation_matrix_z = tf.stack([
tf.cos(gamma), -tf.sin(gamma),tf.constant(0.0),
tf.sin(gamma), tf.cos(gamma),tf.constant(0.0),
tf.constant(0.0),tf.constant(0.0),tf.constant(1.0)])
rotation_matrix_x = tf.reshape(rotation_matrix_x, (3,3))
rotation_matrix_y = tf.reshape(rotation_matrix_y, (3,3))
rotation_matrix_z = tf.reshape(rotation_matrix_z, (3,3))
s = tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z))
r = tf.matmul(tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z)) ,self.kernel)
x,y,z = tf.split(r,3,axis=0)
X = tf.reshape(x,[-1])
Y = tf.reshape(y,[-1])
Z = tf.reshape(z,[-1])
X_lower = tf.math.floor(X) # tf.clip_by_value(tf.math.floor(X),0,self.dimensions)
X_upper = tf.math.ceil(X) # tf.clip_by_value(tf.math.ceil(X),0,self.dimensions)
Y_lower = tf.math.floor(Y) # tf.clip_by_value(tf.math.floor(Y),0,self.dimensions)
Y_upper = tf.math.ceil(Y) # tf.clip_by_value(tf.math.ceil(Y),0,self.dimensions)
Z_lower = tf.math.floor(Z) # tf.clip_by_value(tf.math.floor(Z),0,self.dimensions)
Z_upper = tf.math.ceil(Z) #tf.clip_by_value(tf.math.ceil(Z),0,self.dimensions)
x_d = (X-X_lower+0.001)/(X_upper-X_lower+0.001)
y_d = (Y-Y_lower+0.001)/(Y_upper-Y_lower+0.001)
z_d = (Z-Z_lower+0.001)/(Z_upper-Z_lower+0.001)
coord_000 = tf.stack([X_lower,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_001 = tf.stack([X_lower,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_011 = tf.stack([X_lower,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_111 = tf.stack([X_upper,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_101 = tf.stack([X_upper,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_100 = tf.stack([X_upper,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_010 = tf.stack([X_lower,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_110 = tf.stack([X_upper,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
c000 = tf.gather_nd(voxels,tf.cast(coord_000,tf.int32))
c001 = tf.gather_nd(voxels,tf.cast(coord_001,tf.int32))
c011 = tf.gather_nd(voxels,tf.cast(coord_011,tf.int32))
c111 = tf.gather_nd(voxels,tf.cast(coord_111,tf.int32))
c101 = tf.gather_nd(voxels,tf.cast(coord_101,tf.int32))
c100 = tf.gather_nd(voxels,tf.cast(coord_100,tf.int32))
c010 = tf.gather_nd(voxels,tf.cast(coord_010,tf.int32))
c110 = tf.gather_nd(voxels,tf.cast(coord_110,tf.int32))
x_d = tf.expand_dims(tf.expand_dims(x_d,axis=1),axis=1)
y_d = tf.expand_dims(tf.expand_dims(y_d,axis=1),axis=1)
z_d = tf.expand_dims(tf.expand_dims(z_d,axis=1),axis=1)
c00 = c000*(1-x_d) + c100*x_d
c01 = c001*(1-x_d) + c101*x_d
c10 = c010*(1-x_d) + c110*x_d
c11 = c011*(1-x_d) + c111*x_d
c0 = c00*(1-y_d) + c10*y_d
c1 = c01*(1-y_d) + c11*y_d
c = c0*(1-z_d)+c1*z_d
c = tf.transpose(c,perm=[1,0,2])
out = tf.reshape(c,[batchdim,dimensions,dimensions,dimensions,channels])
return out
def call(self,voxels,alpha,beta,gamma):
image =tf.expand_dims(tf.map_fn(lambda x: self.rotation_map(tf.squeeze(voxels),x[0],x[1],x[2]),[alpha,beta,gamma],dtype=tf.float32),axis=-1)
s = tf.shape(image)
image = tf.reshape(image,[s[0]*s[1],s[2],s[3],s[4],s[5]])
return image
def matmul_func(first,second,third):
a = tf.matmul(first,second,transpose_a=True)#*third
a = tf.tile(a,[2,1,1])
third = tf.transpose(third,perm=[2,0,1])
return a*third
def matmul_func_2D(first,second):
a = tf.matmul(first,second,transpose_a=True)#*third
return tf.reshape(tf.reshape(a,[-1]),[2,2])
class Instance_norm(tf.keras.layers.Layer):
def __init__(self,return_mean):
super(Instance_norm, self).__init__()
self.return_mean = return_mean
def build(self,input_shape):
self.offset = self.add_weight("offset",
shape=[input_shape[-1]],trainable=True)
self.scale = self.add_weight("scale",
shape=[input_shape[-1]],trainable=True)
def call(self,x):
mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
epsilon = 1e-5
inv = 1/tf.sqrt(variance + epsilon)
normalized = (input - mean) * inv
if self.return_mean:
return self.scale * normalized + self.offset, mean, variance
else:
return self.scale * normalized + self.offset
class Spectral_norm(tf.keras.layers.Layer):
def __init__(self,channels,kernels=3,strides=2):
super(Spectral_norm, self).__init__()
self.channels = channels
self.strides = strides
self.kernels = kernels
def build(self,input_shape):
self.u = self.add_weight("offset",
shape=[self.kernels,self.kernels,input_shape[-1],output_dim])
self.bias = self.add_weight("offset",
shape=[1,input_shape[-1]])
def call(self,x):
w_shape = tf.shape(x)
x = tf.reshape(x, [-1, w_shape[-1]])
u_hat = self.u
v_ = tf.matmul(u_hat, tf.transpose(w))
v_hat = l2_norm(v_)
u_ = tf.matmul(v_hat, w)
u_hat = l2_norm(u_)
sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
w_norm = w / sigma
w_norm = tf.reshape(w_norm, w_shape)
conv = tf.nn.conv2d(input_, w_norm, strides=[1, self.strides, self.strides, 1], padding='SAME')+self.bias
return conv
def draw_from_distribution(width,ctf_params,projections_count):
translate_x= translate_y = np.linspace(-np.floor(width/2),np.ceil(width/2),projections_count)
gauss_prob = np.exp(-(translate_x**2)/(width/2)**2)
gauss_prob = gauss_prob/np.sum(gauss_prob)
a = np.arange(ctf_params.shape[0])
sa = np.random.choice(a,projections_count)
ctf_pr_count = np.take_along_axis(ctf_params,np.expand_dims(sa,axis=1),axis=0)
t_x_pr_count = np.random.choice(translate_x,projections_count,p=gauss_prob)
t_y_pr_count = np.random.choice(translate_y,projections_count,p=gauss_prob)
psi_count = np.random.choice(np.random.uniform(0,1,36),projections_count)
rho_count = np.random.choice(np.random.uniform(0,1,36),projections_count)
inplane_count = np.random.choice(np.random.uniform(0,1,36),projections_count)
return ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count
def projection(model,gamma,beta,alpha,translations):
gamma = tf.cast(gamma,tf.float32)
beta = tf.cast(beta,tf.float32)
alpha = tf.cast(alpha,tf.float32)
translations = tf.cast(translations,tf.float32)
rot_model = tf.transpose(tfa.image.rotate(model,2*3.1415*gamma,interpolation='bilinear'),perm=[1,0,2])
rot_model = tfa.image.rotate(rot_model,2*3.1415*beta,interpolation='bilinear')
#projection =tfa.image.rotate(tf.reduce_sum(rot_model,axis=2),2*3.1415*alpha,interpolation='bilinear') # tfa.image.translate(tfa.image.rotate(tf.reduce_sum(rot_model,axis=2),2*3.1415*alpha,interpolation='bilinear'),translations,interpolation='bilinear')
return tf.expand_dims(tf.reduce_sum(rot_model,axis=2),axis=-1)
def projection_map(model,gamma,beta,alpha,translations):
model = tf.squeeze(tf.cast(model,tf.float32))
images = tf.map_fn(lambda x: projection(model,x[0],x[1],x[2],x[3]),[gamma,beta,alpha,translations],dtype=tf.float32)
return images
def translate(image_stack,translate_x,translate_y):
shifted_images = tfa.image.translate(tf.expand_dims(image_stack,axis=-1),tf.cast(tf.stack([translate_y,translate_x],axis=1),tf.float32))
return tf.expand_dims(image_stack,axis=-1) #image_stack
def mm_models(model_size,num_linear_params):
model_vectors = np.random.uniform(low=0.0, high=1.0, size=[model_size**3,num_linear_params])
q,r = np.linalg.qr(model_vectors)
return q
def sample_params(batch_size):
a = np.random.uniform(size=[batch_size])
#a = a/np.reduce_sum(a,axis=1)
return a
def apply_ctf(image,ctf_params,gamma,spherical_abberation,w_2):
image = tf.cast(image,tf.float32)
ctf_params = tf.cast(ctf_params,tf.float32)
gamma = tf.cast(gamma,tf.float32)
spherical_abberation = tf.cast(spherical_abberation,tf.float32)
w_2 = tf.cast(w_2,tf.float32)
scale_dimension = tf.shape(image)[2]
phase,defocus_1,defocus_2,ast = tf.split(ctf_params,4,axis=1)
w_1 = tf.sqrt(1-w_2)
linear_map = tf.linspace(-1+0.001,1+0.001,scale_dimension)
X,Y = tf.meshgrid(linear_map,linear_map)
a_g = tf.reshape(tf.math.tan(Y/X),[-1])
a_g = tf.tile(tf.expand_dims(a_g,axis=0),[tf.shape(ctf_params)[0],1])
g = X**2+Y**2
g = tf.reshape(g,[-1])
g = tf.tile(tf.expand_dims(g,axis=0),[tf.shape(ctf_params)[0],1])
delta_f = 0.5*(defocus_1+defocus_2+(defocus_1-defocus_2)*tf.cos(2*(a_g-ast)))
t = tf.math.atan(w_2/w_1)
xi = 3.1415*gamma*g*(delta_f-0.5*gamma**2*g*spherical_abberation) + phase + t
CTF = -tf.math.sin(xi)
s = tf.cast(tf.sqrt(tf.cast(tf.shape(CTF)[1],tf.float32)),tf.int32)
CTF = tf.cast(tf.reshape(CTF,[tf.shape(CTF)[0],s,s]),tf.complex64)
image = tf.cast(image,tf.complex64)
ctf_image = tf.signal.ifft2d(tf.signal.fftshift(tf.signal.fftshift(tf.signal.fft2d(image),axes=(1,2))*CTF,axes=(1,2)))
return tf.cast(image,tf.float32)
def z_maker(model,num_linear_params):
k = np.random.random(num_linear_params)
#k = k/np.sum(k)
z = np.dot(model,k)
s = np.cbrt(z.shape[0]).astype(np.int32)
z = np.reshape(z,[1,s,s,s,1])
z = np.concatenate([z]*10,axis=0)
return z,k
def project_a_model(self,model):
side_1 = tf.reduce_sum(model,axis=1)
side_2 = tf.reduce_sum(model,axis=2)
side_3 = tf.reduce_sum(model,axis=3)
axis_one = tf.concat([tf.zeros([scale,scale]),side_1],axis=2)
axis_two = tf.concat([side_2,side_3],axis=2)
combined = tf.concat([axis_one,axis_two],axis=1)
combined = tf.concat([tf.split(combined,tf.shape(combined)[0])],axis=0)
return combined
"""
c = '/emcc/misser11/sortem_old/unetvoid/particle_stack_dir'
lambdas = np.load(join(c,'electron_volts.npy'))
spher_abb = np.load(join(c,'spherical_abberation.npy'))
ac = np.load(join(c,'amplitude_contrast.npy'))
ctf_params = np.load(join(c,'ctf_params.npy'))
projections_count = 100
ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count = draw_from_distribution(32,ctf_params,projections_count)
s = mm_models(8,3)
out = z_maker(s,3)
model = np.ones([32,32,32])
proj = projection_map(model,inplane_count,psi_count,rho_count)
proj = apply_ctf(proj,ctf_pr_count,lambdas,spher_abb,ac)
proj = translate(proj,t_x_pr_count,t_y_pr_count)
"""
"""
def rotation_map(voxels,alpha,beta,gamma):
dimensions = tf.shape(voxels)[0]
x = tf.cast(tf.range(-tf.cast(tf.floor(tf.cast(dimensions,tf.float32)/2.0),tf.int32),tf.cast(tf.cast(dimensions,tf.float32)-tf.floor(tf.cast(dimensions,tf.float32)/2.0),tf.int32)),tf.float32)
X,Y,Z = tf.meshgrid(x,x,x)
coordinates = tf.stack([tf.reshape(X,[-1]), tf.reshape(Y,[-1]),tf.reshape(Z,[-1])],axis=0)
kernel = tf.cast(coordinates,tf.float32)
alpha = 2*3.1415*tf.cast(alpha,tf.float32)
beta = 2*3.1415*tf.cast(beta,tf.float32)
gamma = 2*3.1415*tf.cast(gamma,tf.float32)
rotation_matrix_x =tf.stack([tf.constant(1.0),tf.constant(0.0),tf.constant(0.0),
tf.constant(0.0),tf.cos(alpha), -tf.sin(alpha),
tf.constant(0.0),tf.sin(alpha), tf.cos(alpha)])
rotation_matrix_y = tf.stack([
tf.cos(beta),tf.constant(0.0), tf.sin(beta),
tf.constant(0.0),tf.constant(1.0),tf.constant(0.0),
-tf.sin(beta),0, tf.cos(beta)])
rotation_matrix_z = tf.stack([
tf.cos(gamma), -tf.sin(gamma),tf.constant(0.0),
tf.sin(gamma), tf.cos(gamma),tf.constant(0.0),
tf.constant(0.0),tf.constant(0.0),tf.constant(1.0)])
rotation_matrix_x = tf.reshape(rotation_matrix_x, (3,3))
rotation_matrix_y = tf.reshape(rotation_matrix_y, (3,3))
rotation_matrix_z = tf.reshape(rotation_matrix_z, (3,3))
s = tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z))
r = tf.matmul(tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z)) ,kernel)
x,y,z = tf.split(r,3,axis=0)
X = tf.reshape(x,[-1])
Y = tf.reshape(y,[-1])
Z = tf.reshape(z,[-1])
X_lower = tf.math.floor(X) # tf.clip_by_value(tf.math.floor(X),0,self.dimensions)
X_upper = tf.math.ceil(X) # tf.clip_by_value(tf.math.ceil(X),0,self.dimensions)
Y_lower = tf.math.floor(Y) # tf.clip_by_value(tf.math.floor(Y),0,self.dimensions)
Y_upper = tf.math.ceil(Y) # tf.clip_by_value(tf.math.ceil(Y),0,self.dimensions)
Z_lower = tf.math.floor(Z) # tf.clip_by_value(tf.math.floor(Z),0,self.dimensions)
Z_upper = tf.math.ceil(Z) #tf.clip_by_value(tf.math.ceil(Z),0,self.dimensions)
x_d = (X-X_lower+0.001)/(X_upper-X_lower+0.001)
y_d = (Y-Y_lower+0.001)/(Y_upper-Y_lower+0.001)
z_d = (Z-Z_lower+0.001)/(Z_upper-Z_lower+0.001)
coord_000 = tf.stack([X_lower,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_001 = tf.stack([X_lower,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_011 = tf.stack([X_lower,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_111 = tf.stack([X_upper,Y_upper,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_101 = tf.stack([X_upper,Y_lower,Z_upper],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_100 = tf.stack([X_upper,Y_lower,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_010 = tf.stack([X_lower,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
coord_110 = tf.stack([X_upper,Y_upper,Z_lower],axis=1)+tf.cast(tf.floor(dimensions/2),tf.float32)
c000 = tf.gather_nd(voxels,tf.cast(coord_000,tf.int32))
c001 = tf.gather_nd(voxels,tf.cast(coord_001,tf.int32))
c011 = tf.gather_nd(voxels,tf.cast(coord_011,tf.int32))
c111 = tf.gather_nd(voxels,tf.cast(coord_111,tf.int32))
c101 = tf.gather_nd(voxels,tf.cast(coord_101,tf.int32))
c100 = tf.gather_nd(voxels,tf.cast(coord_100,tf.int32))
c010 = tf.gather_nd(voxels,tf.cast(coord_010,tf.int32))
c110 = tf.gather_nd(voxels,tf.cast(coord_110,tf.int32))
c00 = c000*(1-x_d) + c100*x_d
c01 = c001*(1-x_d) + c101*x_d
c10 = c010*(1-x_d) + c110*x_d
c11 = c011*(1-x_d) + c111*x_d
c0 = c00*(1-y_d) + c10*y_d
c1 = c01*(1-y_d) + c11*y_d
c = c0*(1-z_d)+c1*z_d
return tf.reduce_sum(tf.reshape(c,[dimensions,dimensions,dimensions]),axis=0)
import mrcfile
with mrcfile.open('/emcc/misser11/EMPIAR_10317/emd_4775.mrc') as mrc:
voxel = mrc.data
import matplotlib.pyplot as plt
im = rotation_map(voxel,0.0,0.0,0.3)
plt.imshow(im)
plt.savefig('test_it.png')
exit()
"""
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment