utils.py 3.12 KB
Newer Older
Jonathan Juhl's avatar
all  
Jonathan Juhl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

import tensorflow as tf
import numpy as np

def draw_from_distribution(width, ctf_params, projections_count, scale=1.0):
    """Randomly draw per-projection CTF parameters, translations and angles.

    Args:
        width: Number of candidate integer translations per axis; the
            translation grid is ``np.arange(width)``.
        ctf_params: Array whose rows are candidate CTF parameter sets. Rows
            are sampled uniformly with replacement.
        projections_count: Number of projections to draw.
        scale: Standard deviation of the zero-mean translation Gaussian,
            expressed as a multiple of ``width / 2``. Defaults to 1.0.

    Returns:
        Tuple ``(ctf, t_x, t_y, psi, rho, inplane)``:
        - ``ctf``: the sampled CTF rows.
        - ``t_x``, ``t_y``: the sampled x and y translations.
        - ``psi``, ``rho``, ``inplane``: three angles, each drawn uniformly
          from 36 evenly spaced values in [0, 1].

    Raises:
        ValueError: if ``width`` or ``scale`` is not positive.
    """
    if width <= 0:
        raise ValueError("width must be positive")
    if scale <= 0:
        raise ValueError("scale must be positive")

    # NOTE(review): the grid starts at 0 and the Gaussian is zero-mean, so the
    # shifts follow a half-Gaussian over non-negative values. Confirm whether
    # signed (symmetric) shifts were intended.
    translations = np.arange(width)
    sigma = scale * (width / 2)
    # Discretised N(0, sigma) weights turned into a probability vector for
    # np.random.choice (a single np.random.normal sample cannot serve as `p`).
    gauss_prob = np.exp(-0.5 * (translations / sigma) ** 2)
    gauss_prob = gauss_prob / np.sum(gauss_prob)

    ctf_params = np.asarray(ctf_params)
    rows = np.random.choice(ctf_params.shape[0], projections_count)
    # np.take selects whole rows for any rank; take_along_axis needed the
    # index array to match ctf_params' ndim.
    ctf_pr_count = np.take(ctf_params, rows, axis=0)
    t_x_pr_count = np.random.choice(translations, projections_count, p=gauss_prob)
    t_y_pr_count = np.random.choice(translations, projections_count, p=gauss_prob)

    angle_grid = np.linspace(0, 1, 36)
    psi_count = np.random.choice(angle_grid, projections_count)
    rho_count = np.random.choice(angle_grid, projections_count)
    inplane_count = np.random.choice(angle_grid, projections_count)
    return ctf_pr_count, t_x_pr_count, t_y_pr_count, psi_count, rho_count, inplane_count

@tf.function
def projection(model, gamma, beta, alpha):
    """Rotate a volume by three angles and return its rotated sum-projection.

    Angles are fractions of a full turn and are converted to radians. The
    steps are:
    1. Rotate the volume by ``gamma``.
    2. Transpose with ``perm=[0, 2, 1]`` and rotate by ``beta``.
    3. Transpose with ``perm=[2, 1, 0]``.
    4. Sum along axis 1.
    5. Rotate the resulting 2-D image by ``alpha``.

    ``model`` is assumed to be a rank-3 volume, which ``tfa.image.rotate``
    treats as (H, W, C) — confirm against callers.

    NOTE(review): ``tfa`` (tensorflow_addons) is not imported in the visible
    file; it must be in scope for this to run.
    """
    two_pi = 2 * np.pi
    rotated_model = tfa.image.rotate(model, gamma * two_pi)
    rotated_model = tf.transpose(rotated_model, perm=[0, 2, 1])
    rotated_model = tfa.image.rotate(rotated_model, beta * two_pi)
    rotated_model = tf.transpose(rotated_model, perm=[2, 1, 0])
    projected = tf.reduce_sum(rotated_model, axis=1)
    # alpha was previously scaled by 2.31415, a typo for 2*pi as used for the
    # other two angles.
    return tfa.image.rotate(projected, alpha * two_pi)


@tf.function
def translate(image_stack, translate_x, translate_y):
    """Shift images by ``(translate_x, translate_y)`` pixels.

    Args:
        image_stack: Images in a layout accepted by ``tfa.image.translate``.
        translate_x: Horizontal shift(s) (``dx``); a scalar, or one per image.
        translate_y: Vertical shift(s) (``dy``); a scalar, or one per image.

    Returns:
        The translated images.

    NOTE(review): ``tfa`` (tensorflow_addons) is not imported in the visible
    file; it must be in scope for this to run.
    """
    # tfa.image.translate expects [dx, dy]: shape [2] for a single shift or
    # [num_images, 2] for a batch. Stack on the last axis (axis 0 produced
    # [2, num_images]) with dx first, and cast because integer shifts (e.g.
    # from np.random.choice) are not converted implicitly.
    translations = tf.cast(tf.stack([translate_x, translate_y], axis=-1), tf.float32)
    return tfa.image.translate(image_stack, translations)

def mm_models(model_size, num_linear_params):
    """Return an orthonormal basis built from random volumes.

    Draws a ``(model_size**3, num_linear_params)`` matrix with entries uniform
    in [0, 1). Its columns are then orthonormalised with a reduced QR
    decomposition.

    Args:
        model_size: Side length of the cubic volume.
        num_linear_params: Number of basis vectors (columns) to draw.

    Returns:
        ``q`` of shape ``(model_size**3, min(model_size**3, num_linear_params))``
        with orthonormal columns.
    """
    model_vectors = np.random.uniform(
        low=0.0, high=1.0, size=[model_size ** 3, num_linear_params]
    )
    # `linalg` was never imported. np.linalg.qr defaults to "reduced" mode,
    # which keeps q at (M, K) instead of a full (M, M) matrix.
    q, _ = np.linalg.qr(model_vectors)
    return q

def sample_params(batch_size):
    """Draw ``batch_size`` non-negative random weights that sum to one.

    Weights are drawn uniformly from [0, 1) and then normalised by their sum.

    Args:
        batch_size: Number of weights to draw.

    Returns:
        1-D array of shape ``(batch_size,)`` whose entries sum to 1.
    """
    # Fixes: `unfiorm` typo, `shape=` -> `size=`, nonexistent np.reduce_sum,
    # and axis=1 on a 1-D array.
    weights = np.random.uniform(size=[batch_size])
    return weights / np.sum(weights)

@tf.function
def apply_ctf(image, scale_dimension, ctf_params, gamma, spherical_abberation, w_2):
    """Weight the Fourier transform of a batch of images by a CTF.

    The phase is::

        xi = pi*gamma*g*(delta_f - 0.5*gamma**2*g*spherical_abberation)
             + phase + atan(w_2 / w_1)

    and the CTF is ``-sin(xi)``. Here ``g = X**2 + Y**2`` on a
    ``scale_dimension``-square grid, and ``delta_f`` interpolates between the
    two defocus values by azimuth.

    Args:
        image: Batch of real or complex images. Assumed to be
            ``[batch, scale_dimension, scale_dimension]``, since the shifts
            act on axes 1 and 2 — confirm.
        scale_dimension: Side length of the frequency grid.
        ctf_params: Per-image ``(phase, defocus_1, defocus_2, ast)``, split
            along axis 1. Assumed ``[batch, 4]`` float32.
        gamma: Scalar multiplying the frequency terms (presumably the
            wavelength — confirm).
        spherical_abberation: Coefficient of the ``g**2`` phase term
            (presumably spherical aberration Cs).
        w_2: Amplitude-contrast weight; ``w_1 = sqrt(1 - w_2**2)``.

    Returns:
        Complex tensor: the inverse 2-D FFT of the CTF-weighted spectrum.
    """
    # [batch, 1] -> [batch, 1, 1] so each image's parameters broadcast over
    # its own [S, S] grid rather than against the grid's leading axis.
    phase, defocus_1, defocus_2, ast = [
        p[..., tf.newaxis] for p in tf.split(ctf_params, 4, axis=1)
    ]

    # With w_1 = sqrt(1 - w_2**2), -sin(xi0 + atan(w_2/w_1)) equals
    # -w_1*sin(xi0) - w_2*cos(xi0). The previous sqrt(1 - w_2) did not.
    w_1 = tf.sqrt(1 - w_2 ** 2)

    # NOTE(review): the 0.001 offset appears to exist to avoid X == 0 for the
    # old Y/X division; atan2 no longer needs it, but it is kept so the
    # frequency grid is unchanged.
    linear_map = tf.linspace(-1 + 0.001, 1 + 0.001, scale_dimension)
    X, Y = tf.meshgrid(linear_map, linear_map)
    # Azimuthal angle of each grid point (tan(Y/X) was not an angle).
    a_g = tf.math.atan2(Y, X)
    g = X ** 2 + Y ** 2
    delta_f = 0.5 * (
        defocus_1 + defocus_2 + (defocus_1 - defocus_2) * tf.cos(2 * (a_g - ast))
    )

    t = tf.math.atan(w_2 / w_1)
    xi = np.pi * gamma * g * (delta_f - 0.5 * gamma ** 2 * g * spherical_abberation) + phase + t
    ctf = -tf.math.sin(xi)

    # fft2d requires complex input, and TF does not mix complex and real
    # dtypes in a multiply, so cast both operands.
    spectrum = tf.signal.fft2d(tf.cast(image, tf.complex64))
    spectrum = tf.signal.fftshift(spectrum, axes=(1, 2))
    spectrum = spectrum * tf.cast(ctf, spectrum.dtype)
    # ifftshift undoes fftshift (they differ for odd sizes).
    ctf_image = tf.signal.ifft2d(tf.signal.ifftshift(spectrum, axes=(1, 2)))

    return ctf_image

def z_maker(model,num_linear_params):
    """Compute a product and reshape it into ``num_linear_params`` cubes.

    The result is reshaped to ``(num_linear_params, n, n, n)`` where
    ``n = model.shape[0]``.

    NOTE(review): ``model_size`` is not defined in this function or anywhere
    in the visible file, so calling this raises ``NameError``. The intended
    first operand of ``np.dot`` is unclear (``model``? a basis matrix?).
    Confirm before fixing.

    ``num_linear_params`` is presumably an int, since it is used as a reshape
    dimension. In that case ``np.dot`` reduces to a scalar multiplication.
    """
    # NOTE(review): `model_size` is undefined here (see docstring).
    z = np.dot(model_size,num_linear_params)
    # Split into num_linear_params cubes of side model.shape[0].
    z  = np.reshape(z,(num_linear_params,model.shape[0],model.shape[0],model.shape[0]))
    return z

@tf.function
def project_a_model(self,model):
    """Tile three axis-aligned sum-projections of a batch of volumes.

    ``model`` is summed along axes 1, 2 and 3, so it must be at least rank 4
    (presumably ``[batch, S, S, S]`` — confirm). The projections are
    arranged as a 2x2 mosaic. Blocks are concatenated along axis 2 within a
    row and along axis 1 between rows::

        [ zeros  | side_1 ]
        [ side_2 | side_3 ]

    NOTE(review): several problems prevent this from running as written:
      * ``scale`` is undefined in this function and in the visible file.
      * ``tf.zeros([scale, scale])`` is rank 2 while ``side_1`` is rank 3, so
        the first ``tf.concat`` fails on rank mismatch.
        ``tf.zeros_like(side_1)`` may be what was meant (for cubic volumes).
      * The last line splits along axis 0 and concatenates along axis 0,
        which would be a no-op. However:
        - the extra list brackets wrap the pieces in a nested list;
        - ``tf.split`` is given a rank-0 tensor split count
          (``tf.shape(combined)[0]``), which it does not accept — verify.
        Intended layout unclear — confirm.
      * ``self`` is unused; this looks like a method moved to module level.
    """
    # Sum-projections along each spatial axis.
    side_1 = tf.reduce_sum(model,axis=1)
    side_2 = tf.reduce_sum(model,axis=2)
    side_3 = tf.reduce_sum(model,axis=3)
    # NOTE(review): `scale` undefined; rank-2 zeros vs rank-3 side_1 (see docstring).
    axis_one = tf.concat([tf.zeros([scale,scale]),side_1],axis=2)
    axis_two = tf.concat([side_2,side_3],axis=2)
    # Stack the two rows vertically into the 2x2 mosaic.
    combined = tf.concat([axis_one,axis_two],axis=1)
    # NOTE(review): broken split/concat (see docstring).
    combined = tf.concat([tf.split(combined,tf.shape(combined)[0])],axis=0)
    return combined