Commit 6f43a18f authored by Jonathan Juhl

Delete image_restoration_sortem.py

parent 214d1be5
import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Flatten
from os.path import join, isfile

from super_clas_sortem import super_class
from models import Unet

class Denoise(super_class):

    def __init__(self, parameter_file_path, not_projected_star_files, projected_star_files,
                 bytes_pr_record, validate=None, val_bytes=None):
        super_class.__init__(self, parameter_file_path)
        self.not_projected_star_files = not_projected_star_files
        self.validate = validate
        self.val_bytes = val_bytes
        self.projected_star_files = projected_star_files
        self.paths = [self.not_projected_star_files, self.projected_star_files]
        self.bytes_pr_record = bytes_pr_record
        # Total number of optimisation steps across all epochs.
        num_steps = int(self.max_particles / (self.batch_size * self.num_gpus)) * self.epochs
        self.steps = num_steps
        self.opt = self.optimizer(self.steps)
        self.unet = Unet()
        self.train()
    @tf.function
    def mask(self, image):
        # Blind-spot masking: pick one random pixel per column and overwrite it with a
        # value taken from elsewhere in the same image, so the network cannot simply
        # copy its input at those positions. Returns the masked image and the mask.
        image = tf.squeeze(image)
        x = tf.range(64)
        s = tf.random.shuffle(x)
        _, y = tf.meshgrid(x, x)
        # Replacement values gathered through the shifted permutation, one per row.
        rs = tf.reduce_sum(tf.cast(tf.equal(y, tf.roll(s, int(64 / 2), axis=0)), self.precision) * image, axis=2)
        selected_pixels = tf.cast(tf.equal(y, s), self.precision)
        not_selected_pixels = tf.cast(tf.not_equal(y, s), self.precision)
        m_image = tf.transpose(
            tf.transpose(tf.stack([selected_pixels] * self.batch_size, axis=0), perm=[1, 0, 2]) * rs,
            perm=[1, 0, 2])
        return m_image + not_selected_pixels * image, selected_pixels
    @tf.function
    def loss(self, image, estimate, mask):
        # Mean squared error evaluated only at the masked (replaced) pixel positions.
        a = Flatten()(image)
        b = Flatten()(estimate)
        c = tf.reshape(mask, [-1])
        return tf.reduce_mean(((a - b) * c) ** 2)
    def plotlib(self, image, raw_image):
        # Tile the denoised and raw images side by side and save the comparison figure.
        s = np.concatenate(np.split(image, image.shape[0]), axis=1)
        t = np.concatenate(np.split(raw_image, raw_image.shape[0]), axis=1)
        plt.imshow(np.squeeze(np.concatenate([s, t], axis=2)), cmap='gray')
        plt.savefig(join(self.results, 'image_signal.png'))
    @tf.function
    def predict_net_low(self, raw_data):
        return self.unet(raw_data)
    @tf.function
    def train_net_L(self, raw_data_image):
        # One gradient step on the blind-spot loss for a batch of raw images.
        raw_data_image = tf.cast(raw_data_image, self.precision)
        swapped_pixels, mask = self.mask(raw_data_image)
        swapped_pixels = tf.expand_dims(swapped_pixels, axis=-1)
        with tf.GradientTape() as tape:
            estimate = self.unet(swapped_pixels)
            loss = self.loss(raw_data_image, estimate, mask)
        variables = self.unet.trainable_weights
        self.apply_grad(loss, variables, tape)
        return loss
    def train(self):
        strategy, distribute = self.generator('contrastive', self.validate, self.val_bytes, self.batch_size)
        dis = iter(distribute)
        if self.validate is not None:
            strategy_validate, distribute_val = self.generator('predict', self.validate, self.val_bytes,
                                                               self.predict_batch_size)
            dis_val = iter(distribute_val)
            pred_data = next(dis_val)
            if self.verbose:
                pred_data, y = pred_data
        # Train only if no saved checkpoint exists yet.
        if not isfile(join(self.models, 'unet.index')):
            ite = 1
            while True:
                raw_data_image, perm = next(dis)
                loss = strategy.run(self.train_net_L, args=(perm,))
                if ite % self.validate_interval == 0 and self.validate:
                    validation_images = strategy_validate.run(self.predict_net_low, args=(pred_data,))
                    self.plotlib(validation_images, pred_data)
                    self.unet.save_weights(join(self.models, 'unet'))
                print("step: %i of %i" % (ite, self.steps), loss.numpy())
                ite += 1
                # Stop once the configured number of steps has been reached.
                if ite > self.steps:
                    break