Commit 9a47d116 authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

Update utils_sortem.py

parent 95fd1e85
import tensorflow as tf
import numpy as np
from os.path import join
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import math
from tensorflow.keras.layers import Dense,LeakyReLU
......@@ -57,23 +57,11 @@ class transform_3D(tf.keras.layers.Layer):
beta =2*3.1415*tf.cast(beta,tf.float32)
gamma = 2*3.1415*tf.cast(gamma,tf.float32)
voxels = tf.transpose(voxels,perm=[1,2,3,0,4])
dimensions = tf.shape(voxels)[0]
x = tf.cast(tf.range(-tf.cast(tf.floor(tf.cast(dimensions,tf.float32)/2.0),tf.int32),tf.cast(tf.cast(dimensions,tf.float32)-tf.floor(tf.cast(dimensions,tf.float32)/2.0),tf.int32)),tf.float32)
X,Y,Z = tf.meshgrid(x,x,x)
tmp = tf.reshape(voxels,[-1])
X = tf.reshape(X,[-1]); Y = tf.reshape(Y,[-1]);Z = tf.reshape(Z,[-1])
X_mass = tf.round(tf.reduce_mean(X*tmp))
Y_mass = tf.round(tf.reduce_mean(Y*tmp))
Z_mass = tf.round(tf.reduce_mean(Z*tmp))
coordinates = tf.stack([X,Y,Z],axis=0)
kernel = tf.cast(coordinates,tf.float32)
channels = tf.shape(voxels)[-1]
batchdim = tf.shape(voxels)[-2]
rotation_matrix_x =tf.stack([tf.constant(1.0),tf.constant(0.0),tf.constant(0.0),
tf.constant(0.0),tf.cos(alpha), -tf.sin(alpha),
......@@ -93,20 +81,20 @@ class transform_3D(tf.keras.layers.Layer):
rotation_matrix_z = tf.reshape(rotation_matrix_z, (3,3))
s = tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z))
r = tf.matmul(tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z)) ,kernel)
r = tf.matmul(tf.matmul(rotation_matrix_x,tf.matmul(rotation_matrix_y,rotation_matrix_z)) ,self.kernel)
x,y,z = tf.split(r,3,axis=0)
X = tf.reshape(x,[-1])
Y = tf.reshape(y,[-1])
Z = tf.reshape(z,[-1])
X_lower = tf.math.floor(X)-X_mass # tf.clip_by_value(tf.math.floor(X),0,self.dimensions)
X_upper = tf.math.ceil(X)-X_mass # tf.clip_by_value(tf.math.ceil(X),0,self.dimensions)
X_lower = tf.math.floor(X) # tf.clip_by_value(tf.math.floor(X),0,self.dimensions)
X_upper = tf.math.ceil(X) # tf.clip_by_value(tf.math.ceil(X),0,self.dimensions)
Y_lower = tf.math.floor(Y)-Y_mass # tf.clip_by_value(tf.math.floor(Y),0,self.dimensions)
Y_upper = tf.math.ceil(Y)-Y_mass # tf.clip_by_value(tf.math.ceil(Y),0,self.dimensions)
Y_lower = tf.math.floor(Y) # tf.clip_by_value(tf.math.floor(Y),0,self.dimensions)
Y_upper = tf.math.ceil(Y) # tf.clip_by_value(tf.math.ceil(Y),0,self.dimensions)
Z_lower = tf.math.floor(Z)-Z_mass # tf.clip_by_value(tf.math.floor(Z),0,self.dimensions)
Z_upper = tf.math.ceil(Z)-Z_mass #tf.clip_by_value(tf.math.ceil(Z),0,self.dimensions)
Z_lower = tf.math.floor(Z) # tf.clip_by_value(tf.math.floor(Z),0,self.dimensions)
Z_upper = tf.math.ceil(Z) #tf.clip_by_value(tf.math.ceil(Z),0,self.dimensions)
x_d = (X-X_lower+0.001)/(X_upper-X_lower+0.001)
y_d = (Y-Y_lower+0.001)/(Y_upper-Y_lower+0.001)
......@@ -129,6 +117,10 @@ class transform_3D(tf.keras.layers.Layer):
c010 = tf.gather_nd(voxels,tf.cast(coord_010,tf.int32))
c110 = tf.gather_nd(voxels,tf.cast(coord_110,tf.int32))
x_d = tf.expand_dims(tf.expand_dims(x_d,axis=1),axis=1)
y_d = tf.expand_dims(tf.expand_dims(y_d,axis=1),axis=1)
z_d = tf.expand_dims(tf.expand_dims(z_d,axis=1),axis=1)
c00 = c000*(1-x_d) + c100*x_d
c01 = c001*(1-x_d) + c101*x_d
c10 = c010*(1-x_d) + c110*x_d
......@@ -138,11 +130,16 @@ class transform_3D(tf.keras.layers.Layer):
c1 = c01*(1-y_d) + c11*y_d
c = c0*(1-z_d)+c1*z_d
c = tf.transpose(c,perm=[1,0,2])
out = tf.reshape(c,[batchdim,dimensions,dimensions,dimensions,channels])
return tf.reduce_sum(tf.reshape(c,[dimensions,dimensions,dimensions]),axis=0)
return out
def call(self,voxels,alpha,beta,gamma):
image =tf.expand_dims(tf.map_fn(lambda x: self.rotation_map(tf.squeeze(voxels),x[0],x[1],x[2]),[alpha,beta,gamma],dtype=tf.float32),axis=-1)
s = tf.shape(image)
image = tf.reshape(image,[s[0]*s[1],s[2],s[3],s[4],s[5]])
return image
......@@ -311,11 +308,13 @@ def apply_ctf(image,ctf_params,gamma,spherical_abberation,w_2):
return tf.cast(image,tf.float32)
def z_maker(model,num_linear_params):
k = np.random.random(num_linear_params)
#k = k/np.sum(k)
z = np.dot(model,k)
s = np.cbrt(z.shape[0]).astype(np.int32)
z = np.reshape(z,[1,s,s,s,1])
z = np.concatenate([z]*10,axis=0)
return z,k
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment