Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonathan Juhl
SortEM
Commits
a54a942b
Commit
a54a942b
authored
May 25, 2021
by
Jonathan Juhl
Browse files
correct loss function
parent
90bb6c10
Changes
2
Hide whitespace changes
Inline
Side-by-side
fac_sortem.py
View file @
a54a942b
...
...
@@ -55,21 +55,6 @@ class DynAE(super_class):
return
prob_density
@
tf
.
function
def
interpolation
(
class_centroids
,
num_interpolants
):
normalized
=
tf
.
l2_norm
(
class_centroids
)
tmp
=
tf
.
matmul
(
normalized
,
normalized
,
transpose_b
=
True
)
t_target
=
tf
.
argmin
(
tmp
,
axis
=
1
)
t_source
=
tf
.
argmin
(
tf
.
rduce_min
(
tmp
,
axis
=
1
))
lin_space
=
tf
.
linspace
(
0
,
1
,
self
.
interpolation_num_samples
)
strait_line_vector
=
normalized
[
t_target
]
*
lin_space
+
(
1
-
lin_space
)
*
normalized
[
t_source
]
softweighted_vectors
=
tf
.
matmul
(
tf
.
nn
.
softmax
(
tf
.
matmul
(
strait_line_vector
,
normalized
,
transpose_b
=
True
)),
class_centroids
,
transpose_b
=
True
)
return
softweighted_vectors
@
tf
.
function
def
predict_cluster
(
self
,
num_classes
,
angular
,
images
):
num_classes
=
tf
.
one_hot
(
num_classes
,
self
.
num_parts
)
...
...
@@ -132,12 +117,12 @@ class DynAE(super_class):
lz
=
loss_latent
(
predict_z_f
,
catagorial
)
loss_tot
+=
lz
lg
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
h5_sig_f
,
tf
.
ones_like
(
h5_sig_f
)))
lg
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
h5_sig_f
)
,
h5_sig_f
))
loss_tot
+=
lg
grad
=
t
.
gradient
(
loss_tot
,
self
.
G
.
trainable_variables
+
self
.
cluster_Layer
.
trainable_variables
)
self
.
g_opt
.
apply_gradients
(
zip
(
grad
,
self
.
G
.
trainable_variables
+
self
.
cluster_Layer
.
trainable_variables
))
return
l
z
,
lg
,
add_images
return
l
g
,
lg
,
add_images
@
tf
.
function
def
train_d
(
self
,
x_real
,
take_components
,
angular
,
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
,
lambdas
,
spher_abb
,
ac
):
...
...
@@ -223,25 +208,10 @@ class DynAE(super_class):
for
i
in
range
(
2
):
takes
=
np
.
random
.
choice
(
np
.
arange
(
self
.
num_parts
),
self
.
batch_size
)
angular
=
np
.
random
.
choice
(
np
.
arange
(
self
.
angular_cluster
**
2
),
self
.
batch_size
)
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
=
draw_from_distribution
(
self
.
large_rescale
,
ctf_params
,
self
.
batch_size
)
if
self
.
verbose
:
image
,
y
=
next
(
dist_it
)
image
=
image
.
numpy
()
y
=
y
.
numpy
()
else
:
image
=
next
(
dist_it
)
image
=
image
.
numpy
()
image
=
tf
.
cast
(
image
,
self
.
precision
)
loss_d
,
loss_style
,
loss_z
,
images
=
strategy
.
run
(
self
.
train_d
,(
image
,
takes
,
angular
,
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
,
lambdas
,
spher_abb
,
ac
))
# loss_1,
l_list
.
append
([
loss_d
,
loss_style
,
loss_z
])
takes
=
np
.
random
.
choice
(
np
.
arange
(
self
.
num_parts
),
self
.
batch_size
)
angular
=
np
.
random
.
choice
(
np
.
arange
(
self
.
angular_cluster
**
2
),
self
.
batch_size
)
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
=
draw_from_distribution
(
self
.
large_rescale
,
ctf_params
,
self
.
batch_size
)
if
self
.
verbose
:
image
,
y
=
next
(
dist_it
)
image
=
image
.
numpy
()
...
...
@@ -250,15 +220,32 @@ class DynAE(super_class):
image
=
next
(
dist_it
)
image
=
image
.
numpy
()
takes
=
np
.
random
.
choice
(
np
.
arange
(
self
.
num_parts
),
self
.
batch_size
)
angular
=
np
.
random
.
choice
(
np
.
arange
(
self
.
angular_cluster
**
2
),
self
.
batch_size
)
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
=
draw_from_distribution
(
self
.
large_rescale
,
ctf_params
,
self
.
batch_size
)
loss_g
,
loss_q
,
images
=
strategy
.
run
(
self
.
train_g
,(
image
,
takes
,
angular
,
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
,
lambdas
,
spher_abb
,
ac
))
image
=
tf
.
cast
(
image
,
self
.
precision
)
loss_d
,
loss_style
,
loss_z
,
images
=
strategy
.
run
(
self
.
train_d
,(
image
,
takes
,
angular
,
ctf_pr_count
,
t_x_pr_count
*
0
,
t_y_pr_count
*
0
,
t_z_pr_count
*
0
,
inplane_count
*
0
,
lambdas
,
spher_abb
,
ac
))
# loss_1,
loss_d
=
loss_d
.
numpy
()
loss_style
=
loss_style
.
numpy
()
loss_style
=
loss_style
.
numpy
()
loss_z
=
loss_z
.
numpy
()
print
(
'training time: '
,
time
()
-
t0
,
"step: %i of %i "
%
(
ite
,
self
.
steps
),
loss_d
,
loss_style
,
loss_z
)
l_list
.
append
([
loss_d
,
loss_style
,
loss_z
])
if
self
.
verbose
:
image
,
y
=
next
(
dist_it
)
image
=
image
.
numpy
()
y
=
y
.
numpy
()
else
:
image
=
next
(
dist_it
)
image
=
image
.
numpy
()
for
i
in
range
(
4
):
takes
=
np
.
random
.
choice
(
np
.
arange
(
self
.
num_parts
),
self
.
batch_size
)
angular
=
np
.
random
.
choice
(
np
.
arange
(
self
.
angular_cluster
**
2
),
self
.
batch_size
)
ctf_pr_count
,
t_x_pr_count
,
t_y_pr_count
,
t_z_pr_count
,
inplane_count
=
draw_from_distribution
(
self
.
large_rescale
,
ctf_params
,
self
.
batch_size
)
loss_q
,
loss_g
,
images
=
strategy
.
run
(
self
.
train_g
,(
image
,
takes
,
angular
,
ctf_pr_count
*
0
,
t_x_pr_count
*
0
,
t_y_pr_count
*
0
,
t_z_pr_count
*
0
,
inplane_count
*
0
,
lambdas
,
spher_abb
,
ac
))
loss_g
=
loss_g
.
numpy
()
print
(
'training time: '
,
time
()
-
t0
,
"step: %i of %i "
%
(
ite
,
self
.
steps
),
loss_d
,
loss_style
,
loss_z
,
loss_g
)
take
=
[]
#print(ite % self.validate_interval)
...
...
utils_sortem.py
View file @
a54a942b
import
tensorflow
as
tf
import
matplotlib
matplotlib
.
use
(
'Agg'
)
...
...
@@ -47,22 +49,22 @@ def loss_latent(predict,catagories):
def
loss_gen
(
D_logits_
,
predict_z
,
z
):
g_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
D_logits_
,
tf
.
ones_like
(
D_logits_
)))
+
loss_latent
(
predict_z
,
z
)
g_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
D_logits_
)
,
D_logits_
))
+
loss_latent
(
predict_z
,
z
)
return
g_loss
def
loss_encode
(
d_h1_r
,
d_h1_f
,
d_h2_r
,
d_h2_f
,
d_h3_r
,
d_h3_f
,
d_h4_r
,
d_h4_f
):
d_h1_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h1_r
,
tf
.
ones_like
(
d_h1_r
)))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h1_f
,
tf
.
zeros_like
(
d_h1_f
)))
d_h1_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
d_h1_r
)
,
d_h1_r
))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
zeros_like
(
d_h1_f
)
,
d_h1_f
))
d_h2_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h2_r
,
tf
.
ones_like
(
d_h2_r
)))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h2_f
,
tf
.
zeros_like
(
d_h2_f
)))
d_h2_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
d_h2_r
)
,
d_h2_r
))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
zeros_like
(
d_h2_f
)
,
d_h2_f
))
d_h3_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h3_r
,
tf
.
ones_like
(
d_h3_r
)))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h3_f
,
tf
.
zeros_like
(
d_h3_f
)))
d_h3_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
d_h3_r
)
,
d_h3_r
))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
zeros_like
(
d_h3_f
)
,
d_h3_f
))
d_h4_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h4_r
,
tf
.
ones_like
(
d_h4_r
)))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
d_h4_f
,
tf
.
zeros_like
(
d_h4_f
)))
d_h4_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
d_h4_r
)
,
d_h4_r
))
\
+
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
zeros_like
(
d_h4_f
)
,
d_h4_f
))
return
d_h1_loss
+
d_h2_loss
+
d_h3_loss
+
d_h4_loss
...
...
@@ -70,8 +72,8 @@ def loss_encode(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f):
def
loss_disc
(
D_logits
,
D_logits_fake
,
predict_z
,
z
):
z
=
tf
.
cast
(
z
,
tf
.
float32
)
d_loss_real
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
D_logits
,
tf
.
ones_like
(
D_logits
)))
d_loss_fake
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
D_logits_fake
,
tf
.
zeros_like
(
D_logits_fake
)))
d_loss_real
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
ones_like
(
D_logits
)
,
D_logits
))
d_loss_fake
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
tf
.
zeros_like
(
D_logits_fake
)
,
D_logits_fake
))
d_loss
=
d_loss_real
+
d_loss_fake
...
...
@@ -87,7 +89,9 @@ class transform_3D(tf.keras.layers.Layer):
self
.
full_image
=
full_image
def
build
(
self
,
input_shape
):
self
.
dimensions
=
input_shape
[
1
]
self
.
dimensions
=
input_shape
[
-
2
]
self
.
channels
=
input_shape
[
-
1
]
x
=
tf
.
range
(
-
int
(
np
.
floor
(
self
.
dimensions
/
2
)),
int
(
np
.
ceil
(
self
.
dimensions
/
2
)))
X
,
Y
,
Z
=
tf
.
meshgrid
(
x
,
x
,
x
)
...
...
@@ -104,9 +108,7 @@ class transform_3D(tf.keras.layers.Layer):
y_translate
=
tf
.
cast
(
y_translate
,
tf
.
float32
)
z_translate
=
tf
.
cast
(
z_translate
,
tf
.
float32
)
#voxels = tf.transpose(voxels,perm=[1,2,3,0,4])
dimensions
=
tf
.
shape
(
voxels
)[
0
]
channels
=
tf
.
shape
(
voxels
)[
-
1
]
#batchdim = tf.shape(voxels)[-2]
rotation_matrix_x
=
tf
.
stack
([
tf
.
ones_like
(
alpha
),
tf
.
zeros_like
(
alpha
),
tf
.
zeros_like
(
alpha
),
...
...
@@ -127,11 +129,11 @@ class transform_3D(tf.keras.layers.Layer):
rotation_matrix_y
=
tf
.
reshape
(
rotation_matrix_y
,
(
3
,
3
))
rotation_matrix_z
=
tf
.
reshape
(
rotation_matrix_z
,
(
3
,
3
))
s
=
tf
.
matmul
(
rotation_matrix_x
,
tf
.
matmul
(
rotation_matrix_y
,
rotation_matrix_z
))
r
=
tf
.
matmul
(
tf
.
matmul
(
rotation_matrix_x
,
tf
.
matmul
(
rotation_matrix_y
,
rotation_matrix_z
))
,
self
.
kernel
)
x
,
y
,
z
=
tf
.
split
(
r
,
3
,
axis
=
0
)
X
=
tf
.
reshape
(
x
,[
-
1
])
+
x_translate
*
(
self
.
dimensions
/
self
.
full_image
)
Y
=
tf
.
reshape
(
y
,[
-
1
])
+
y_translate
*
(
self
.
dimensions
/
self
.
full_image
)
Z
=
tf
.
reshape
(
z
,[
-
1
])
+
z_translate
*
(
self
.
dimensions
/
self
.
full_image
)
...
...
@@ -148,19 +150,19 @@ class transform_3D(tf.keras.layers.Layer):
y_d
=
(
Y
-
Y_lower
+
0.001
)
/
(
Y_upper
-
Y_lower
+
0.001
)
z_d
=
(
Z
-
Z_lower
+
0.001
)
/
(
Z_upper
-
Z_lower
+
0.001
)
coord_000
=
tf
.
stack
([
X_lower
,
Y_lower
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_001
=
tf
.
stack
([
X_lower
,
Y_lower
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_011
=
tf
.
stack
([
X_lower
,
Y_upper
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_111
=
tf
.
stack
([
X_upper
,
Y_upper
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_101
=
tf
.
stack
([
X_upper
,
Y_lower
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_100
=
tf
.
stack
([
X_upper
,
Y_lower
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_010
=
tf
.
stack
([
X_lower
,
Y_upper
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_110
=
tf
.
stack
([
X_upper
,
Y_upper
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
dimensions
/
2
),
tf
.
float32
)
coord_000
=
tf
.
stack
([
X_lower
,
Y_lower
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_001
=
tf
.
stack
([
X_lower
,
Y_lower
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_011
=
tf
.
stack
([
X_lower
,
Y_upper
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_111
=
tf
.
stack
([
X_upper
,
Y_upper
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_101
=
tf
.
stack
([
X_upper
,
Y_lower
,
Z_upper
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_100
=
tf
.
stack
([
X_upper
,
Y_lower
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_010
=
tf
.
stack
([
X_lower
,
Y_upper
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
coord_110
=
tf
.
stack
([
X_upper
,
Y_upper
,
Z_lower
],
axis
=
1
)
+
tf
.
cast
(
tf
.
floor
(
self
.
dimensions
/
2
),
tf
.
float32
)
#voxels = tf.reshape(voxels,[dimensions**3,channels])
c000
=
tf
.
gather_nd
(
voxels
,
tf
.
cast
(
coord_000
,
tf
.
int32
))
# print(c000);exit()
c001
=
tf
.
gather_nd
(
voxels
,
tf
.
cast
(
coord_001
,
tf
.
int32
))
c011
=
tf
.
gather_nd
(
voxels
,
tf
.
cast
(
coord_011
,
tf
.
int32
))
...
...
@@ -175,7 +177,7 @@ class transform_3D(tf.keras.layers.Layer):
z_d
=
tf
.
expand_dims
(
z_d
,
axis
=
1
)
c00
=
c000
*
(
1
-
x_d
)
+
c100
*
x_d
c01
=
c001
*
(
1
-
x_d
)
+
c101
*
x_d
c10
=
c010
*
(
1
-
x_d
)
+
c110
*
x_d
c11
=
c011
*
(
1
-
x_d
)
+
c111
*
x_d
...
...
@@ -186,8 +188,8 @@ class transform_3D(tf.keras.layers.Layer):
c
=
c0
*
(
1
-
z_d
)
+
c1
*
z_d
out
=
tf
.
reshape
(
c
,[
dimensions
,
dimensions
,
dimensions
,
channels
])
out
=
tf
.
reshape
(
c
,[
self
.
dimensions
,
self
.
dimensions
,
self
.
dimensions
,
self
.
channels
])
return
out
def
call
(
self
,
voxels
,
alpha
,
beta
,
gamma
,
x_translate
,
y_translate
,
z_translate
):
...
...
@@ -384,9 +386,17 @@ def apply_ctf(image,ctf_params,KVolts,spherical_abberation,w2):
return
ctf_image
#import mrcfile
"""
data = np.load('data.npy')
data = np.reshape(data,(1,256,256,256,1))
t = transform_3D(256)
v = np.asarray([0.0])
k = np.asarray([128])
out = t(data,v,v,v,k,k,v)
plt.imshow(np.squeeze(np.sum(out,axis=1)))
plt.savefig('try_it.png')
#import mrcfile"""
"""V = 300
lambdas = 10**(-4)*10**(-6)*12.25*10**(-10)/np.sqrt(V*10**3)
...
...
@@ -431,4 +441,4 @@ with mrcfile.open('/emcc/misser11/EMPIAR_10317/out_noise.mrcs') as mrc:
plt.clf()
t+=1
"""
\ No newline at end of file
"""
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment