Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonathan Juhl
SortEM
Commits
6f43a18f
Commit
6f43a18f
authored
Dec 01, 2021
by
Jonathan Juhl
Browse files
Delete image_restoration_sortem.py
parent
214d1be5
Changes
1
Hide whitespace changes
Inline
Side-by-side
image_restoration_sortem.py
deleted
100644 → 0
View file @
214d1be5
import
matplotlib
matplotlib
.
use
(
'Agg'
)
import
tensorflow
as
tf
from
tensorflow.keras.layers
import
Conv2D
,
UpSampling2D
,
Activation
,
GlobalAveragePooling2D
,
PReLU
,
Conv2DTranspose
,
BatchNormalization
,
Flatten
import
matplotlib.pyplot
as
plt
plt
.
switch_backend
(
'agg'
)
from
mrc_loader_sortem
import
mrc_loader
import
numpy
as
np
from
os.path
import
join
from
os.path
import
isfile
from
super_clas_sortem
import
super_class
from
models
import
Unet
class
Denoise
(
super_class
):
def
__init__
(
self
,
parameter_file_path
,
not_projected_star_files
,
projected_star_files
,
bytes_pr_record
,
validate
=
None
,
val_bytes
=
None
):
super_class
.
__init__
(
self
,
parameter_file_path
)
self
.
not_projected_star_files
=
not_projected_star_files
self
.
validate
=
validate
self
.
val_bytes
=
val_bytes
self
.
projected_star_files
=
projected_star_files
self
.
paths
=
[
self
.
not_projected_star_files
,
self
.
projected_star_files
]
self
.
bytes_pr_record
=
bytes_pr_record
num_steps
=
(
self
.
max_particles
/
(
self
.
batch_size
*
self
.
num_gpus
))
*
self
.
epochs
self
.
steps
=
num_steps
self
.
opt
=
self
.
optimizer
(
self
.
steps
)
self
.
unet
=
Unet
()
self
.
train
()
@
tf
.
function
def
mask
(
self
,
image
):
image
=
tf
.
squeeze
(
image
)
x
=
tf
.
range
(
64
)
s
=
tf
.
random
.
shuffle
(
x
)
_
,
y
=
tf
.
meshgrid
(
x
,
x
)
rs
=
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
equal
(
y
,
tf
.
roll
(
s
,
int
(
64
/
2
),
axis
=
0
)),
self
.
precision
)
*
image
,
axis
=
2
)
selected_pixels
=
tf
.
cast
(
tf
.
equal
(
y
,
s
),
self
.
precision
)
not_selected_pixels
=
tf
.
cast
(
tf
.
not_equal
(
y
,
s
),
self
.
precision
)
m_image
=
tf
.
transpose
(
tf
.
transpose
(
tf
.
stack
([
selected_pixels
]
*
self
.
batch_size
,
axis
=
0
),
perm
=
[
1
,
0
,
2
])
*
rs
,
perm
=
[
1
,
0
,
2
])
return
m_image
+
not_selected_pixels
*
image
,
selected_pixels
@
tf
.
function
def
loss
(
self
,
image
,
estimate
,
mask
):
a
=
Flatten
()(
image
)
b
=
Flatten
()(
estimate
)
c
=
tf
.
reshape
(
mask
,[
-
1
])
mean
=
tf
.
reduce_mean
(((
a
-
b
)
*
c
)
**
2
)
return
mean
def
plotlib
(
self
,
image
,
raw_image
):
s
=
np
.
concatenate
(
np
.
split
(
image
,
image
.
shape
[
0
]),
axis
=
1
)
t
=
np
.
concatenate
(
np
.
split
(
raw_image
,
raw_image
.
shape
[
0
]),
axis
=
1
)
plt
.
imshow
(
np
.
squeeze
(
np
.
concatenate
([
s
,
t
],
axis
=
2
)),
cmap
=
'gray'
)
plt
.
savefig
(
join
(
self
.
results
,
'image_signal.png'
),)
@
tf
.
function
def
predict_net_low
(
self
,
raw_data
):
stage3_img
=
self
.
unet
(
raw_data
)
return
stage3_img
@
tf
.
function
def
train_net_L
(
self
,
raw_data_image
):
raw_data_image
=
tf
.
cast
(
raw_data_image
,
self
.
precision
)
swaped_pixels
,
mask
=
self
.
mask
(
raw_data_image
)
#plt.imshow(swaped_pixels[0])
#plt.savefig('test.png');exit()
swaped_pixels
=
tf
.
expand_dims
(
swaped_pixels
,
axis
=-
1
)
with
tf
.
GradientTape
()
as
tape
:
estimate
=
self
.
unet
(
swaped_pixels
)
loss
=
self
.
loss
(
raw_data_image
,
estimate
,
mask
)
variables
=
self
.
unet
.
trainable_weights
self
.
apply_grad
(
loss
,
variables
,
tape
)
return
loss
def
train
(
self
):
strategy
,
distribute
=
self
.
generator
(
'contrastive'
,
self
.
validate
,
self
.
val_bytes
,
self
.
batch_size
)
dis
=
iter
(
distribute
)
if
self
.
validate
!=
None
:
strategy_validate
,
distribute_val
=
self
.
generator
(
'predict'
,
self
.
validate
,
self
.
val_bytes
,
self
.
predict_batch_size
)
dis_val
=
iter
(
distribute_val
)
pred_data
=
next
(
dis_val
)
if
self
.
verbose
:
pred_data
,
y
=
pred_data
if
not
isfile
(
join
(
self
.
models
,
'unet.index'
)):
ite
=
1
while
True
:
raw_data_image
,
perm
=
next
(
dis
)
loss
=
strategy
.
run
(
self
.
train_net_L
,
args
=
(
perm
,))
if
ite
%
self
.
validate_interval
==
0
and
self
.
validate
:
validation_images
=
strategy_validate
.
run
(
self
.
predict_net_low
,
args
=
(
pred_data
,))
self
.
plotlib
(
validation_images
,
pred_data
)
self
.
unet
.
save_weights
(
join
(
self
.
models
,
'unet'
))
print
(
"step:%i of %i"
%
(
ite
,
self
.
steps
),
loss
.
numpy
())
ite
+=
1
if
self
.
steps
>
ite
:
break
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment