Commit ecfe932b authored by Jonathan Juhl

minor bug fixes

parent cee88668
@@ -83,18 +83,16 @@ class GAN_NERF():
data = next(gen)
if self.args['ctf']:
params = {'image':data[0][0],
params = {'image':tf.expand_dims(tf.squeeze(data[0][0]),axis=-1),
'ctf':data[0][1],
'alpha':data[1],
'index':data[2],
'shape':data[3],
'index':data[2],
}
else:
params = {'image':data[0],
params = {'image':tf.expand_dims(tf.squeeze(data[0]),axis=-1),
'ctf': None,
'alpha':data[1],
'index':data[2],
'shape':data[3],
'index':data[2],
}
if self.args['num_gpus'] == 1:
......
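The new `tf.expand_dims(tf.squeeze(...), axis=-1)` wrapping restores a single trailing channel axis after collapsing singleton dimensions from the batched loader output. A minimal sketch of that reshaping; the `[1, B, H, W, 1]` input shape is only an assumption for illustration, not taken from the loader itself:

```python
import tensorflow as tf

# Illustrative only: `batch` stands in for data[0] from the generator above.
# A batch that arrives as [1, B, H, W, 1] is collapsed to [B, H, W] and then
# given back one channel axis, the layout Conv2D layers expect.
batch = tf.zeros([1, 4, 64, 64, 1])
image = tf.expand_dims(tf.squeeze(batch), axis=-1)
print(image.shape)  # (4, 64, 64, 1)
```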
@@ -39,7 +39,7 @@ def main():
parser.add_argument('--ctf', dest='ctf',action='store_true',default=False,help='Use CTF parameters for model.')
parser.add_argument('--noise', dest='noise',action='store_false',default=False ,help='Use the noise generator to generate and scale the noise')
parser.add_argument('--noise', dest='noise',action='store_true',default=False ,help='Use the noise generator to generate and scale the noise')
parser.add_argument('--steps',type=int,default=[10000,100000,100000,10000,100000], nargs='+',help='how many epochs( runs through the dataset) before termination')
......
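The `--noise` fix above matters because `action='store_false'` combined with `default=False` can never switch the option on. A small standalone sketch of the corrected behaviour:

```python
import argparse

parser = argparse.ArgumentParser()
# With store_true and default=False, passing --noise flips the value to True;
# the previous store_false/default=False combination left it False either way.
parser.add_argument('--noise', dest='noise', action='store_true', default=False)

print(parser.parse_args([]).noise)           # False
print(parser.parse_args(['--noise']).noise)  # True
```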
@@ -5,19 +5,24 @@ import tensorflow_addons as tfa
import numpy as np
from tensorflow.keras.activations import softplus
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Add,LayerNormalization,Flatten,LeakyReLU, Dense,Conv2D,Layer
from tensorflow.keras.layers import Add,LayerNormalization,Flatten,LeakyReLU, Dense,Conv2D,Layer,UpSampling2D,Conv2DTranspose
from utils_sortem import Poisson_Measure,AverageBlurPooling2D
# the generative loss
class AdapterBlock(Layer):
def __init__(self, output_channels):
super(AdapterBlock,self).__init__()
self.model = Sequential(
Conv2D(output_channels,1),
LeakyReLU(0.2)
)
self.model = Conv2D(output_channels,1,activation= LeakyReLU(0.2))
def call(self, input):
return self.model(input)
class UpAdapterBlock(Layer):
def __init__(self, output_channels):
super(UpAdapterBlock,self).__init__()
self.model = Conv2DTranspose(output_channels,1,activation= LeakyReLU(0.2))
def call(self, input):
return self.model(input)
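The reworked `AdapterBlock` folds the activation into a single 1x1 convolution instead of wrapping it in a `Sequential` (which, as written, was also missing the list around its layers). A hedged sketch showing the two forms compute the same thing, with an arbitrary input shape and channel count:

```python
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, LeakyReLU

x = tf.random.normal([2, 32, 32, 3])

# Old form: a 1x1 convolution followed by a separate LeakyReLU layer
# (note that Sequential expects a *list* of layers).
old = Sequential([Conv2D(16, 1), LeakyReLU(0.2)])

# New form: the same 1x1 convolution with the activation folded in.
new = Conv2D(16, 1, activation=LeakyReLU(0.2))

print(old(x).shape, new(x).shape)  # (2, 32, 32, 16) (2, 32, 32, 16)
```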
class ED_Maker(Model):
@@ -26,21 +31,37 @@ class ED_Maker(Model):
self.img_size_to_layer = {2:8, 4:7, 8:6, 16:5, 32:4, 64:3, 128:2, 256:1, 512:0}
self.fromRGB = fromRGB
self.final_layer = final_layer
self.fromRGB.reverse()
self.start = self.img_size_to_layer[shape]
self.lays = layers
def call(self,input,alpha):
x = self.fromRGB[ int(self.start) ](input)
lays = self.lays[:self.start ]
lays.reverse()
for i, layer in enumerate(lays):
def call(self,input):
x = self.fromRGB[start](input)
for i, layer in enumerate(self.layers[start:]):
if i == 1:
x = alpha * x + (1 - alpha) * self.fromRGB[start+1](tf.image.resize(input, [int(tf.shape(x)[1])/2,int(tf.shape(x)[1])/2], mode='nearest'))
x = alpha * x + (1 - alpha) * self.fromRGB[self.start-1](tf.image.resize(input, [int(tf.shape(x)[1]),int(tf.shape(x)[1])], method='nearest'))
x = layer(x)
out = tf.squeeze(self.final_layer(x))
x = self.final_layer(x).reshape(x.shape[0], 1)
return x
return out,x
def get_vars(self):
v = []
for i in self.lays[:self.start ]:
v+= i.trainable_variables
return v+self.fromRGB[self.start-1].trainable_variables+self.fromRGB[ int(self.start) ].trainable_variables
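The corrected `ED_Maker.call` performs the usual progressive-growing fade-in: features at the current resolution are blended with a `fromRGB` projection of the resized input, weighted by `alpha`, and `tf.image.resize` now receives the valid `method='nearest'` keyword instead of `mode`. A minimal sketch of that blend, with hypothetical layer and tensor shapes:

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D

# Hypothetical stand-ins for one entry of the fromRGB list and the input image.
from_rgb = Conv2D(64, 1)
image = tf.random.normal([2, 64, 64, 1])
x = tf.random.normal([2, 32, 32, 64])   # features after the first block
alpha = 0.3                              # fade-in weight in [0, 1]

# Blend the new path with the projected, resized input, as in progressive
# growing: as alpha -> 1 the newly added block takes over completely.
side = tf.shape(x)[1]
skip = from_rgb(tf.image.resize(image, [side, side], method='nearest'))
x = alpha * x + (1 - alpha) * skip
```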
class ResidualConvBlock(Layer):
def __init__(self,inplanes, planes, kernel_size=3, stride=1, downsample=False, groups=1):
@@ -55,15 +76,15 @@ class ResidualConvBlock(Layer):
self.proj = Conv2D(planes, 1) if inplanes != planes else None
self.downsample = downsample
self.UpSampling2D = UpSampling2D(2)
def forward(self, identity):
self.UpSampling2D_1 = AverageBlurPooling2D(2)
self.UpSampling2D_2 = AverageBlurPooling2D(2)
def call(self, identity):
y = self.network(identity)
if self.downsample: y = self.UpSampling2D (y)
if self.downsample: identity = self.UpSampling2D (identity)
if self.downsample: y = self.UpSampling2D_1(y)
if self.downsample: identity = self.UpSampling2D_2(identity)
identity = identity if self.proj is None else self.proj(identity)
y = (y + identity)/math.sqrt(2)
y = (y + identity)/np.sqrt(2)
return y
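Both residual blocks now implement Keras' `call` instead of the PyTorch-style `forward`, keep one pooling/upsampling instance per branch, and scale the summed branches by `1/np.sqrt(2)` (the `math` module was never imported). A simplified stand-in for the same pattern, using the stock `AveragePooling2D` in place of the custom `AverageBlurPooling2D`:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, LeakyReLU, AveragePooling2D

class TinyResidualBlock(Layer):
    """Simplified stand-in for ResidualConvBlock: conv path plus projected skip."""
    def __init__(self, planes, downsample=False):
        super().__init__()
        self.network = Conv2D(planes, 3, padding='same', activation=LeakyReLU(0.2))
        self.proj = Conv2D(planes, 1)       # match channel counts on the skip
        self.downsample = downsample
        self.pool_y = AveragePooling2D(2)   # one pooling instance per branch
        self.pool_id = AveragePooling2D(2)

    def call(self, identity):
        y = self.network(identity)
        if self.downsample:
            y = self.pool_y(y)
            identity = self.pool_id(identity)
        identity = self.proj(identity)
        # Scale the sum so the output variance stays roughly constant.
        return (y + identity) / np.sqrt(2)

out = TinyResidualBlock(32, downsample=True)(tf.random.normal([2, 16, 16, 8]))
print(out.shape)  # (2, 8, 8, 32)
```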
class ResidualUpBlock(Layer):
@@ -79,15 +100,15 @@ class ResidualUpBlock(Layer):
self.proj = Conv2D(planes, 1) if inplanes != planes else None
self.downsample = downsample
self.averageblurpooling_2D = AverageBlurPooling2D(2)
def forward(self, identity):
self.UpSampling2D_1 = UpSampling2D(2)
self.UpSampling2D_2 = UpSampling2D(2)
def call(self, identity):
y = self.network(identity)
if self.downsample: y = self.averageblurpooling_2D(y)
if self.downsample: identity = self.averageblurpooling_2D(identity)
if self.downsample: y = self.UpSampling2D_1(y)
if self.downsample: identity = self.UpSampling2D_2(identity)
identity = identity if self.proj is None else self.proj(identity)
y = (y + identity)/math.sqrt(2)
y = (y + identity)/np.sqrt(2)
return y
@@ -98,14 +119,14 @@ class ResidualUpBlock(Layer):
class Map(Layer):
def __init__(self,batch_size,image_size,ctf):
def __init__(self):
super(Map,self).__init__()
self.dense_map_0 = Dense(128,activation= LeakyReLU(0.2),kernel_initializer=tf.keras.initializers.HeNormal())
self.dense_map_1 = Dense(128,activation= LeakyReLU(0.2),kernel_initializer=tf.keras.initializers.HeNormal())
self.dense_map_2 = Dense(128,activation= LeakyReLU(0.2),kernel_initializer=tf.keras.initializers.HeNormal())
self.dense_map_3 = Dense(128,activation= LeakyReLU(0.2),kernel_initializer=tf.keras.initializers.HeNormal())
self.dense_map_4 = Dense(256*2*8)
def call(self,image):
def call(self,feature_vector):
x = self.dense_map_0(feature_vector)
x = self.dense_map_1(x)
@@ -120,9 +141,9 @@ class Ray_maker(Layer):
def __init__(self):
super(Ray_maker,self).__init__()
self.L = 10*3*2
self.init_L = 10
self.L = self.init_L*3*2
self.map = Map()
self.dense_0 = Dense(256,kernel_initializer=tf.random_uniform_initializer(-np.sqrt(6/self.L),np.sqrt(6/self.L)))
self.dense_1 = Dense(256,kernel_initializer=tf.random_uniform_initializer(-np.sqrt(6/256),np.sqrt(6/256)))
self.dense_2 = Dense(256,kernel_initializer=tf.random_uniform_initializer(-np.sqrt(6/256),np.sqrt(6/256)))
@@ -133,24 +154,34 @@ class Ray_maker(Layer):
self.dense_7 = Dense(256,kernel_initializer=tf.random_uniform_initializer(-np.sqrt(6/256),np.sqrt(6/256)))
self.sigma_dense = Dense(2)
def build(self,input_shape):
self.batch_size = input_shape[0]
@tf.function
def apply_phase(self,alpha,beta,inputs):
#inputs = tf.reshape(inputs,[self.batch_size,-1,256])
inputs = tf.transpose(inputs,perm=[1,0,2])
inputs = tf.add(tf.multiply(alpha,inputs),beta)
return tf.transpose(inputs,perm=[1,0,2])
@tf.function
def call(self,inputs):
def build_frequencies(self,coordinates):
coordinates = tf.expand_dims(coordinates,axis=2)
s = tf.shape(coordinates)[1]
powered = tf.expand_dims(tf.pow(tf.cast(2.0,dtype=coordinates.dtype),tf.cast(tf.range(self.init_L),dtype=coordinates.dtype)),axis=0)
out = tf.matmul(coordinates,powered,coordinates.dtype)
out = tf.reshape(out,[-1,s,self.init_L*3])
out = tf.concat([tf.cos(np.pi*out),tf.sin(np.pi*out)],axis=-1)
return out
@tf.function
def call(self,im_size,feature_vector,coordinates):
if self.ctf:
im_size,feature_vector,coordinates = inputs
else:
coordinates = inputs
# print(coordinates);exit()
coordinates = self.build_frequencies(coordinates)
x = tf.reshape(coordinates,[self.batch_size,-1,3])
lists = self.map(feature_vector)
x = tf.reshape(coordinates,[-1,(int(im_size/2)+1)*im_size,self.L])
x = self.dense_0(x)
x = self.apply_phase(lists[0],lists[1],x)
x = tf.sin(30*x)
@@ -176,11 +207,12 @@ class Ray_maker(Layer):
x = self.apply_phase(lists[14],lists[15],x)
x = tf.sin(x)
x = self.sigma_dense(x)
z = tf.reshape(x,[self.batch_size,im_size,int(im_size/2)+1])
z = tf.reshape(x,[-1,im_size,int(im_size/2)+1,2])
a,b = tf.split(z,2,axis=-1)
a = tf.squeeze(a)
b = tf.squeeze(b)
fourie = tf.complex(a,b)
fourie = tf.complex(tf.cast(a,tf.float32),tf.cast(b,tf.float32))
return tf.squeeze(fourie)
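The new `build_frequencies` is a NeRF-style positional encoding: each 3-D coordinate is scaled by powers of two and passed through sine and cosine, giving `2 * 3 * init_L = 60` features per point, which matches `self.L`. A self-contained sketch of that encoding; the function name and the `[B, N, 3]` coordinate layout are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

def fourier_encode(coords, n_freqs=10):
    """NeRF-style encoding: coords [B, N, 3] -> [B, N, 3 * n_freqs * 2]."""
    freqs = tf.pow(2.0, tf.range(n_freqs, dtype=coords.dtype))   # 1, 2, 4, ...
    # Outer product of every coordinate component with every frequency.
    scaled = coords[..., tf.newaxis] * freqs                     # [B, N, 3, n_freqs]
    scaled = tf.reshape(scaled, [tf.shape(coords)[0], -1, 3 * n_freqs])
    return tf.concat([tf.cos(np.pi * scaled), tf.sin(np.pi * scaled)], axis=-1)

enc = fourier_encode(tf.random.uniform([2, 5, 3]))
print(enc.shape)  # (2, 5, 60)
```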
@@ -188,22 +220,32 @@ class Ray_maker(Layer):
class Noise_Maker(Model):
def __init__(self,layers,fromRGB,final_layer,shape):
super(ED_Maker,self).__init__()
super(Noise_Maker,self).__init__()
self.img_size_to_layer = {2:8, 4:7, 8:6, 16:5, 32:4, 64:3, 128:2, 256:1, 512:0}
self.fromRGB = fromRGB
self.final_layer = final_layer
self.start = self.img_size_to_layer[shape]
self.lays = layers
def call(self,input,alpha):
x = self.fromRGB[start](input)
for i, layer in enumerate(self.layers[start:]):
x = self.fromRGB[ 0 ](input)
lays = self.lays[:self.start ]
for i, layer in enumerate(lays):
if i == 1:
x = alpha * x + (1 - alpha) * self.fromRGB[start+1](tf.image.resize(input, [int(tf.shape(x)[1])*2,int(tf.shape(x)[1])*2], mode='nearest'))
x = alpha * x + (1 - alpha) * self.fromRGB[i+1](tf.image.resize(input, [int(tf.shape(x)[1]),int(tf.shape(x)[1])], method='nearest'))
x = layer(x)
out = self.final_layer(x)
x = self.final_layer(x).reshape(x.shape[0], 1)
return x
\ No newline at end of file
return out
@@ -24,15 +24,14 @@ class mrc_loader:
self.ctf_size = 129
self.load = np.load(join(self.df_keys['tmp'],'ctf_params.npy'))
def __premake_image__(self,ins):
def __premake_image__(self,ins,imshape):
image = tf.io.decode_raw(ins,tf.float32)
image = tf.reshape(image,[self.df_keys['size'],self.df_keys['size'],1])
image = tf.image.resize(image,tf.constant([self.df_keys['resize'],self.df_keys['resize']],dtype=tf.int32))
image = tf.image.per_image_standardization(image)
image = image-tf.reduce_min(image)
image = tf.image.resize(image,[imshape,imshape])
return image/tf.reduce_max(image)
def __apply_ctf__(self,parameters):
@@ -43,8 +42,11 @@ class mrc_loader:
alphas = tf.data.Dataset.from_tensor_slices(alphas)
index = tf.data.Dataset.from_tensor_slices(index)
im_shape = tf.data.Dataset.from_tensor_slices(im_shape)
data = tf.data.FixedLengthRecordDataset(self.df_keys['mrc_paths'],self.df_keys['bpr'],num_parallel_reads=self.df_keys['num_cpus'], header_bytes=1024).map(self.__premake_image__,self.df_keys['num_cpus']).repeat()
t+= [data]
data = tf.data.FixedLengthRecordDataset(self.df_keys['mrc_paths'],self.df_keys['bpr'],num_parallel_reads=self.df_keys['num_cpus'], header_bytes=1024)
d = tf.data.Dataset.zip((data,im_shape)).map(self.__premake_image__,self.df_keys['num_cpus']).repeat()
t+= [d]
if self.df_keys['ctf']:
@@ -52,7 +54,7 @@ class mrc_loader:
t += [ctf_params]
f = tf.data.Dataset.zip(tuple(t)).prefetch(self.df_keys['batch_size']).batch(3*self.df_keys['batch_size']*self.df_keys['num_gpus'])
return tf.data.Dataset.zip(( f,alphas,index,im_shape))
return tf.data.Dataset.zip(( f,alphas,index)).batch(1)
def pred_generate(self):
......
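The loader change zips the raw `FixedLengthRecordDataset` with the per-record image shapes before mapping, so `__premake_image__` can take the resize target as a second argument. A toy sketch of that zip-then-map pattern (the record and size values here are invented):

```python
import tensorflow as tf

# Toy stand-ins: `records` plays the role of the FixedLengthRecordDataset and
# `sizes` the per-record target image size zipped in by the loader above.
records = tf.data.Dataset.from_tensor_slices(tf.random.normal([4, 8 * 8]))
sizes = tf.data.Dataset.from_tensor_slices(tf.constant([16, 16, 32, 32]))

def premake_image(record, imshape):
    image = tf.reshape(record, [8, 8, 1])
    image = tf.image.per_image_standardization(image)
    image = image - tf.reduce_min(image)
    image = tf.image.resize(image, [imshape, imshape])
    return image / tf.reduce_max(image)

# Zipping first lets map() receive one element from each dataset per call.
dataset = tf.data.Dataset.zip((records, sizes)).map(premake_image)
for img in dataset.take(1):
    print(img.shape)  # (16, 16, 1)
```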
import tensorflow as tf
from tensorflow.keras import mixed_precision
from models_sortem import Ray_maker,ResidualConvBlock,AdapterBlock,ED_Maker
from utils_sortem import Discriminator_Loss,Generator_Loss,Encoder_Loss,DiversityLoss,Eval_CTF
from models_sortem import Ray_maker,ResidualConvBlock,AdapterBlock,ED_Maker,Noise_Maker,ResidualUpBlock,UpAdapterBlock
from utils_sortem import Discriminator_Loss,Generator_Loss,Encoder_Loss,DiversityLoss,Eval_CTF,Make_Grids,Poisson_Measure
from os.path import join
import matplotlib.pyplot as plt
import numpy as np
@@ -45,101 +45,114 @@ class Trainer:
mixed_precision.set_global_policy('mixed_float16')
else:
self.dtype_enc = tf.float32
fin = Conv2D(2*128+6,2,strides=2)
fin_encode = Conv2D(2*10+6,2,strides=2,padding='SAME')
fin_decode = Conv2D(1,2,strides=2,padding='SAME')
fin_noise = Conv2D(1,1,padding='SAME')
self.make_grids = Make_Grids()
self.poisson_noise = Poisson_Measure()
self.noise = [
ResidualConvBlock(16, 32, downsample=True), # 512x512 -> 256x256
ResidualConvBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualConvBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualConvBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualConvBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualConvBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualConvBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualConvBlock(400, 400, downsample=True), # 4x4 -> 2x2
]
ResidualUpBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualUpBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualUpBlock(400, 400, downsample=True), # 4x4 -> 2x2
ResidualUpBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualUpBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualUpBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualUpBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualUpBlock(16, 32, downsample=True), # 512x512 -> 256x256
]
self.noise_adapter = [
AdapterBlock(16),
AdapterBlock(32),
AdapterBlock(64),
AdapterBlock(128),
AdapterBlock(256),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400)
]
UpAdapterBlock(400),
UpAdapterBlock(400),
UpAdapterBlock(400),
UpAdapterBlock(400),
UpAdapterBlock(256),
UpAdapterBlock(128),
UpAdapterBlock(64),
UpAdapterBlock(32),
UpAdapterBlock(16),
]
self.encoder = [
ResidualConvBlock(16, 32, downsample=True), # 512x512 -> 256x256
ResidualConvBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualConvBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualConvBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualConvBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualConvBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualConvBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualConvBlock(400, 400, downsample=True), # 4x4 -> 2x2
ResidualConvBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualConvBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualConvBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualConvBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualConvBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualConvBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualConvBlock(16, 32, downsample=True), # 512x512 -> 256x256s
]
self.encoder_adapter = [
AdapterBlock(16),
AdapterBlock(32),
AdapterBlock(64),
AdapterBlock(128),
AdapterBlock(256),
self.encoder_adapter = [
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400)
AdapterBlock(256),
AdapterBlock(128),
AdapterBlock(64),
AdapterBlock(32),
AdapterBlock(16),
]
fin = Conv2D(1,2,strides=2)
self.discriminator = [
ResidualConvBlock(16, 32, downsample=True), # 512x512 -> 256x256
ResidualConvBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualConvBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualConvBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualConvBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualConvBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualConvBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualConvBlock(400, 400, downsample=True), # 4x4 -> 2x2
ResidualConvBlock(400, 400, downsample=True), # 8x8 -> 4x4
ResidualConvBlock(400, 400, downsample=True), # 16x16 -> 8x8
ResidualConvBlock(256, 400, downsample=True), # 32x32 -> 16x16
ResidualConvBlock(128, 256, downsample=True), # 64x64 -> 32x32
ResidualConvBlock(64, 128, downsample=True), # 128x128 -> 64x64
ResidualConvBlock(32, 64, downsample=True), # 256x256 -> 128x128
ResidualConvBlock(16, 32, downsample=True), # 512x512 -> 256x256s
]
self.discriminator_adapter = [
AdapterBlock(16),
AdapterBlock(32),
AdapterBlock(64),
AdapterBlock(128),
AdapterBlock(256),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400)
AdapterBlock(400),
AdapterBlock(256),
AdapterBlock(128),
AdapterBlock(64),
AdapterBlock(32),
AdapterBlock(16),
]
self.Encoder = [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=32)]
self.Discriminator = [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin,shape=32)]
self.ctf = [Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'],size=32)]
self.Encoder = [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin_encode,shape=32)]
self.Discriminator = [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=32)]
self.Noise = [Noise_Maker(layers=self.noise ,fromRGB=self.noise_adapter,final_layer=fin,shape=32)]
if args['ctf']:
self.ctf_eval = Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'])
if args['resize'] == 64:
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=64)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin,shape=64)]
self.ctf += [Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'],size=64)]
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin_encode,shape=64)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=64)]
self.Noise += [Noise_Maker(layers=self.noise,fromRGB=self.noise_adapter,final_layer=fin,shape=64)]
if args['resize'] == 128:
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=128)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin,shape=128)]
self.ctf += [Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'],size=128)]
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin_encode,shape=128)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=128)]
self.Noise += [Noise_Maker(layers=self.noise,fromRGB=self.noise_adapter,final_layer=fin,shape=128)]
if args['resize'] == 256:
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=256)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin,shape=256)]
self.ctf += [Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'],size=256)]
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin_encode,shape=256)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=256)]
self.Noise += [Noise_Maker(layers=self.noise,fromRGB=self.noise_adapter,final_layer=fin,shape=256)]
if args['resize'] == 512:
self.Encoder += [ED_Maker(layers=self.encoder,fromRGB=self.encoder_adapter,final_layer=fin,shape=512)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin,shape=512)]
self.ctf += [Eval_CTF(args['kvolts'],args['sphe_ab'],args['amp_contrast'],size=512)]
self.Discriminator += [ED_Maker(layers=self.discriminator,fromRGB=self.discriminator_adapter,final_layer=fin_decode,shape=512)]
self.Noise += [Noise_Maker(layers=self.noise,fromRGB=self.noise_adapter,final_layer=fin,shape=512)]
self.current_gen_loss = 0
self.previous_gen_loss = 10**6
self.gen_op = tf.keras.optimizers.Adam(learning_rate=args['lr_g']) # optimizer for GAN step # (learning_rate=self.lr_linear_decay_generator
@@ -156,9 +169,10 @@ class Trainer:
'Discriminator': self.Discriminator,
'Encoder': self.Encoder,
'Generator': Ray_maker(),
'Noise': self.Noise,
'gen_op': self.gen_op,
'decoder_op':self.decoder_op,
'decoder_op':self.encoder_op,
'encoder_op':self.encoder_op,
}
self.step_variable = self.vd['steps'].numpy()
@@ -186,8 +200,6 @@ class Trainer:
if args['noise']:
self.vd['noise'] = ray_maker()
with self.args['strategy'].scope():
pass
@@ -361,33 +373,60 @@ class Trainer:
self.angels = {'angels_1': [],'angels_2': [],'angels_3':[],'translations_1': [],'translations_2':[]}
@tf.function
def __Encoder_Step__(self,image,ctf,discriminator,encoder,noise):
# @tf.function
def __Encoder_Step__(self,image,ctf,alpha,discriminator,encoder,noise):
out_parameters = {}
with tf.GradientTape() as tape:
feature_vector = encoder(image)
rotation= feature_vector[0:6]
mean = feature_vector[6:128+6]
variance = feature_vector[128+6:]
fake_image = self.vd['Generator'](vector,fourie_grid)
feature_vector,_ = encoder(image,tf.cast(alpha,self.dtype_enc))
rotation= feature_vector[:,-3:]
mean = feature_vector[:,:10]
variance = feature_vector[:,10:20]
shape = tf.shape(image)[1]
fourie_grid = self.make_grids(rotation,shape)
ep = tf.random.normal(shape=tf.shape(mean),dtype=self.dtype_enc)
tmp = tf.sqrt(tf.exp(variance))*ep
z_tilde = tf.add(mean, tmp)
z_tilde = tf.reshape(z_tilde,[-1,10])
fake_image = self.vd['Generator'](shape,z_tilde,fourie_grid)
if self.args['ctf']:
fake_image = tf.signal.irrft2d(fake_image*self.eval_ctf(ctf))
fim = fake_image*self.ctf_eval(ctf,shape)
fake_image = tf.signal.irfft2d(fim)
else:
fake_image = tf.signal.irrft2d(fake_image)
if self.args['noise']:
fake_image = noise(fake_image)
real_features,d_real = discriminator(image)
fake_features,d_fake = discriminator(fake_image)
loss = self.encoder_loss_functions(real_features,[mean,variance,fake_features])
gen_var = self.vd['Encoder'].get_trainable_variables
self.compute_op(loss,gen_var,tape)
sigma = noise(tf.random.normal([self.args['batch_size'],4,4,4]),tf.cast(alpha,self.dtype_enc))
fake_image = self.poisson_noise(fake_image,sigma)
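The rewritten `__Encoder_Step__` slices the encoder output into mean, log-variance, and rotation parts and samples the latent with the standard reparameterization trick, `z = mean + sqrt(exp(log_var)) * eps`, so the sampling stays differentiable. A minimal sketch of just that step, with a hypothetical 26-wide feature vector sliced the same way as above:

```python
import tensorflow as tf

# Hypothetical encoder output laid out as in __Encoder_Step__ above:
# the first 10 entries are means, the next 10 log-variances, the last 3 rotation.
feature_vector = tf.random.normal([4, 26])
mean     = feature_vector[:, :10]
log_var  = feature_vector[:, 10:20]
rotation = feature_vector[:, -3:]

# Reparameterization trick: sample eps ~ N(0, I) and shift/scale it so the
# sampling step stays differentiable with respect to mean and log_var.
eps = tf.random.normal(shape=tf.shape(mean))
z_tilde = mean + tf.sqrt(tf.exp(log_var)) * eps
print(z_tilde.shape)  # (4, 10)
```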