Commit 6f250433 authored by Jonathan Juhl

alpha version

parent 9351160d
......@@ -114,7 +114,7 @@ class control_flow:
if counter == 1:
V = float(row.split()[voltage])
electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1))
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
......@@ -125,7 +125,7 @@ class control_flow:
np.save(join(self.particle_stack_dir,'depth.npy'),f)
np.save(join(self.particle_stack_dir,'names.npy'),names)
np.save(join(self.particle_stack_dir,'electron_volts.npy'),electron_volts)
np.save(join(self.particle_stack_dir,'electron_volts.npy'),V)
np.save(join(self.particle_stack_dir,'spherical_abberation.npy'),abberation_d)
np.save(join(self.particle_stack_dir,'amplitude_contrast.npy'),amp_contrast_d)
np.save(join(self.particle_stack_dir,'ctf_params.npy'),np.asarray(ctf_params))
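Note (not part of the diff): the expression near the top of this hunk, (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1)), has the shape of the relativistic electron wavelength as a function of the acceleration voltage V. A minimal cross-check sketch using the textbook constants follows; the relativistic correction constant below (0.97845e-6) differs slightly from the 1.96e-7 used above, so treat it as a sanity check, not a drop-in replacement.

import numpy as np

def electron_wavelength_pm(V):
    # Relativistic de Broglie wavelength of an electron, in picometres,
    # for an acceleration voltage V given in volts (textbook form).
    return 1226.39 / np.sqrt(V * (1.0 + 0.97845e-6 * V))

# e.g. electron_wavelength_pm(300e3) is roughly 1.97 pm for a 300 kV microscope.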
......@@ -166,7 +166,7 @@ class control_flow:
final_labels = np.load(join(self.refined,'final_labels.npy'))
self.write_star_file(star_files,final_labels)
def add_params(self,parameter_file_path,current_image,binary,num_particles,width):
with open(join(parameter_file_path,'parameters.csv'), 'r', newline='') as file:
writer = csv.reader(file, delimiter = '\t')
parameters = list(writer)[0]
......
......@@ -5,25 +5,18 @@ matplotlib.use('Agg')
import os
from super_clas_sortem import super_class
from time import time
from tensorflow.keras.models import Model
from os.path import join,isfile
from os import listdir
import matplotlib.colors as mcolors
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from scipy.optimize import linear_sum_assignment
import csv, os
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score,accuracy_score,adjusted_rand_score
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import normalized_mutual_info_score, recall_score,accuracy_score
from random import sample
from functools import partial
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from tensorflow.keras.layers import Flatten
from utils_sortem import draw_from_distribution,translate,mm_models,sample_params,apply_ctf,z_maker,project_a_model,loss_disc,loss_gen,loss_encode,loss_latent
from utils_sortem import draw_from_distribution,translate,mm_models,sample_params,apply_ctf,loss_disc,loss_gen,loss_encode,loss_latent
from models import Discriminator_AdaIN_res128,Generator_AdaIN_res128,Generator_AdaIN_Noise
class DynAE(super_class):
......@@ -32,11 +25,11 @@ class DynAE(super_class):
,work_dir
)
self.paths = mrc_paths
self.large_rescale = 128
self.ctf_params = np.load(join(self.particle_stack_dir,'ctf_params.npy'))
self.P = Generator_AdaIN_Noise()
self.G = Generator_AdaIN_res128()
self.D = Discriminator_AdaIN_res128(4**3)
self.D = Discriminator_AdaIN_res128(self.num_parts)
num_steps = (self.max_particles/(self.batch_size*self.num_gpus))*self.epochs
self.steps = num_steps
self.g_opt = self.optimizer(self.steps)
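Note (not part of the diff): num_steps is a float here and is later consumed by range(int(self.steps)). A minimal sketch of the same arithmetic with an explicit integer step count; the standalone argument names are stand-ins for the self attributes used above.

import math

def total_steps(max_particles, batch_size, num_gpus, epochs):
    # Ceil so a partial final batch per epoch still counts as one step.
    steps_per_epoch = math.ceil(max_particles / (batch_size * num_gpus))
    return steps_per_epoch * epochs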
......@@ -45,26 +38,6 @@ class DynAE(super_class):
if self.check_dir(self.unrefined,'final_labals.npy'): # run the program if class not computed
self.train_aci_ae()
def load_from_checkpoint(self,prefix):
d_iter = listdir(self.models)
length = len(list(filter(lambda x: prefix in x and '.index' in x,d_iter)))
# if there are no checkpoints, training restarts from scratch.
num_steps = (self.max_particles/(self.batch_size*self.num_gpus))*self.epochs-length*self.save_model
if length != 0:
try:
newest_checkpoint = max(list(filter(lambda x: prefix in x,d_iter)))
except: pass
else:
newest_checkpoint = None
return num_steps,newest_checkpoint
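Note (not part of the diff): load_from_checkpoint picks max() over file names, which compares them lexicographically, so for example step 9 would sort after step 10. A sketch of selection by parsed step number, under the assumption that checkpoints are named '<prefix>-<step>.index' (the naming convention is not confirmed by this diff):

import re
from os import listdir

def newest_checkpoint(models_dir, prefix):
    # Collect (step, checkpoint_prefix) pairs and return the highest step.
    candidates = []
    for name in listdir(models_dir):
        m = re.match(re.escape(prefix) + r'-(\d+)\.index$', name)
        if m:
            candidates.append((int(m.group(1)), name[:-len('.index')]))
    return max(candidates)[1] if candidates else None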
@tf.function
def cluster(core_means,images):
......@@ -84,59 +57,77 @@ class DynAE(super_class):
@tf.function
def train_g(self,image_0,take_components, ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count,lambdas,spher_abb,ac):
angels = tf.cast(tf.stack([psi_count,rho_count],axis=1) ,self.precision)
image_0 = tf.image.resize(image_0,[128,128])
def train_g(self,image_0,take_components, ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count,lambdas,spher_abb,ac):
add_images = []
catagorial = tf.cast(tf.one_hot(take_components,self.num_parts),self.precision)
with tf.GradientTape() as t:
catagorial = tf.random.normal(shape=[self.sample_size,4,4,4,1])
fake_projections,z = self.G(catagorial,rho_count,psi_count,inplane_count)
if self.include_ctf:
fake_projections = apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac)
if self.include_noise:
noise = tf.squeeze(self.poisson_noise(fake_model,z))
generated_image = tf.expand_dims(fake_model,axis=-1) + sigma*tf.sqrt(tf.abs(ctf_applied_image))*tf.random.normal(shape=[self.sample_size,128,128,1])
h5_sig_f, h5, predict_z_f, d_h1_f, d_h2_f, d_h3_f, d_h4_f,predict_angels_f = self.D(fake_projections)
fake_projections = self.G(catagorial,inplane_count)
add_images.append(fake_projections[0])
if self.ctf:
fake_projections = tf.expand_dims(apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac),axis=-1)
add_images.append(fake_projections[0])
else:
add_images.append(None)
if self.noise:
z = tf.random.normal(shape=[self.batch_size,4,4,1])
noise = self.P(z)
#fake_projections = fake_projections + noise*tf.sqrt(tf.abs(fake_projections))*tf.random.normal(shape=[self.batch_size,self.large_rescale,self.large_rescale,1])
add_images.append(fake_projections[0])
else:
add_images.append(None)
h5_sig_f, h5, predict_z_f, d_h1_f, d_h2_f, d_h3_f, d_h4_f = self.D(fake_projections)
z = Flatten()(z)
loss_tot = 0
if self.use_z:
loss_tot += loss_latent(z,predict_z_f)
loss_tot += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(h5_sig_f, tf.ones_like(h5_sig_f)))+loss_latent(z,predict_z_f)
lz = loss_latent(predict_z_f,catagorial)
loss_tot += lz
lg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(h5_sig_f), logits=h5_sig_f))
loss_tot += lg
grad = t.gradient(loss_tot, self.G.trainable_variables)
self.g_opt.apply_gradients(zip(grad, self.G.trainable_variables))
return loss_tot
return lz,lg,add_images
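Note (not part of the diff): loss_latent is imported from utils_sortem and its body is not shown here. Given the one-hot conditioning above, an InfoGAN-style categorical latent recovery term could look like the sketch below; the assumption that predict_z_f are unnormalised logits over the self.num_parts categories is mine.

import tensorflow as tf

def latent_recovery_loss(one_hot_code, predicted_logits):
    # Cross-entropy between the code fed to the generator and the code
    # the discriminator recovers from the generated image.
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=one_hot_code, logits=predicted_logits))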
@tf.function
def train_d(self, x_real,take_components, ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count,lambdas,spher_abb,ac):
angels = tf.cast(tf.stack([psi_count,rho_count],axis=1) ,self.precision)
def train_d(self, x_real,take_components,ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count,lambdas,spher_abb,ac):
add_images = []
catagorial = tf.cast(tf.one_hot(take_components,self.num_parts),self.precision)
with tf.GradientTape() as t:
catagorial = tf.random.normal(shape=[self.sample_size,4,4,4,1])
fake_projections,z = self.G(catagorial,rho_count,psi_count,inplane_count)
if self.include_ctf:
fake_projections = apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac)
if self.include_noise:
noise = tf.squeeze(self.poisson_noise(fake_model,z))
generated_image = tf.expand_dims(fake_model,axis=-1) + sigma*tf.sqrt(tf.abs(ctf_applied_image))*tf.random.normal(shape=[self.sample_size,128,128,1])
h5_sig_f, h5, predict_z_f, d_h1_f, d_h2_f, d_h3_f, d_h4_f,predict_angels_f = self.D(fake_projections)
h5_sig_r, h5_r, predict_z, d_h1_r, d_h2_r, d_h3_r, d_h4_r,predict_angels_r = self.D(x_real)
z = Flatten()(z)
l_n = loss_encode(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f)
loss_tot = loss_disc(h5_sig_r,h5_sig_f,predict_z_f,z,predict_angels_f,angels,use_z=self.use_z,use_angels=self.use_angels) +l_n
grad = t.gradient(loss_tot, self.D.trainable_variables)
fake_projections = self.G(catagorial,inplane_count)
add_images.append(fake_projections[0])
if self.ctf:
fake_projections = tf.expand_dims(apply_ctf(fake_projections,ctf_pr_count,lambdas,spher_abb,ac),axis=-1)
add_images.append(fake_projections[0])
else:
add_images.append(None)
if self.noise:
z = tf.random.normal(shape=[self.batch_size,4,4,1])
sigma = self.P(z)
fake_projections = fake_projections+ sigma*tf.sqrt(tf.abs(fake_projections))*tf.random.normal(shape=[self.batch_size,128,128,1])
add_images.append(fake_projections[0])
else:
add_images.append(None)
h5_sig_f, h5, predict_z_f, d_h1_f, d_h2_f, d_h3_f, d_h4_f = self.D(fake_projections)
h5_sig_r, h5_r, predict_z, d_h1_r, d_h2_r, d_h3_r, d_h4_r = self.D(x_real)
ln = loss_encode(d_h1_r,d_h1_f,d_h2_r,d_h2_f,d_h3_r,d_h3_f,d_h4_r,d_h4_f)
ld,lz = loss_disc(h5_sig_r,h5_sig_f,predict_z_f,catagorial)
loss_total = ln+ld
grad = t.gradient(loss_total, self.D.trainable_variables)
self.d_opt.apply_gradients(zip(grad, self.D.trainable_variables))
return loss_tot,fake_projections
return ln,ld,lz,add_images
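Note (not part of the diff): both training functions scale Gaussian noise by sqrt(|fake_projections|), i.e. a Gaussian approximation to Poisson shot noise whose variance tracks the signal. A minimal sketch of that noise model, with sigma standing in for the per-pixel scale produced by self.P:

import tensorflow as tf

def add_signal_dependent_noise(images, sigma):
    # Standard deviation grows with the square root of the local signal.
    eps = tf.random.normal(tf.shape(images))
    return images + sigma * tf.sqrt(tf.abs(images)) * eps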
......@@ -148,30 +139,24 @@ class DynAE(super_class):
logfile = open(join( self.results,'%s_log.csv') %training_file_name, 'w')
if self.verbose:
logwriter = csv.DictWriter(logfile, fieldnames=['iter','loss discriminator', 'loss generator', 'loss latent', 'loss angle','NMI','accuracy','recall score'])
logwriter.writeheader() # write the header to the log file in /results/ when training in verbose mode.
logwriter = csv.DictWriter(logfile, fieldnames=['iter','loss discriminator', 'loss generator', 'loss latent', 'loss angle','NMI','accuracy','recall score'])
logwriter.writeheader() # write the header to the log file in /results/ when training in verbose mode.
else:
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.sample_size) # the image loader, for training.
logwriter = csv.DictWriter(logfile, fieldnames=['iter','loss discriminator', 'loss generator', 'loss latent', 'loss angle'])
logwriter.writeheader() # write the header to the log file in /results/ when no verbose data is available.
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.batch_size) # the image loader, for training.
t0 = time()
l_list = []
if self.verbose:
image,y = next(dist_pred_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_pred_it)
image = image.numpy()
dist_it = iter(dist_it)
lambdas = np.load(join(self.particle_stack_dir,'electron_volts.npy'))
spher_abb = np.load(join(self.particle_stack_dir,'spherical_abberation.npy'))
ac = np.load(join(self.particle_stack_dir,'amplitude_contrast.npy'))
ctf_params = np.load(join(self.particle_stack_dir,'ctf_params.npy'))
for i in range(self.steps):
ite = 0
for i in range(int(self.steps)):
......@@ -179,56 +164,70 @@ class DynAE(super_class):
print('training time: ', time() - t0,"step: %i of %i " %(i,self.steps))
for i in range(2):
takes = np.random.choice(np.arange(self.num_parts),self.sample_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.sample_size)
image_0 = next(dist_it)
image_0 = tf.cast(image_0,self.precision)
loss_d,loss_q,loss_a,x_original,x_ctf,x_noise = strategy.run(self.train_d,(image_0,takes, ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count,lambdas,spher_abb,ac)) # loss_1,
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_it)
image = image.numpy()
image = tf.cast(image,self.precision)
loss_d,loss_style,loss_z,images = strategy.run(self.train_d,(image,takes, ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count,lambdas,spher_abb,ac)) # loss_1,
image_0 = next(dist_it)
takes = np.random.choice(np.arange(self.num_parts),self.sample_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.sample_size)
loss_g,loss_q,loss_a,x_original,x_ctf,x_noise = strategy.run(self.train_g,(image_0,takes, ctf_pr_count,t_x_pr_count,t_y_pr_count,psi_count,rho_count,inplane_count,lambdas,spher_abb,ac ))
if self.verbose:
image,y = next(dist_it)
image = image.numpy()
y = y.numpy()
else:
image = next(dist_it)
image = image.numpy()
takes = np.random.choice(np.arange(self.num_parts),self.batch_size)
ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count = draw_from_distribution(self.large_rescale,ctf_params,self.batch_size)
loss_g,loss_q,images = strategy.run(self.train_g,(image,takes, ctf_pr_count,t_x_pr_count,t_y_pr_count,inplane_count,lambdas,spher_abb,ac ))
l_list.append([loss_d,loss_style,loss_z,loss_g,loss_q])
take = []
if ite % self.validate_interval == 0: ¨
if ite % self.validate_interval == 0:
if self.verbose:
nmi = NMI(y,predicted)
acc = ACC(y,predicted)
nmi = normalized_mutual_info_score(y,predicted)
acc = accuracy_score(y,predicted)
rs = recall_score(y,predicted)
logfile.write(l_list+[nmi,acc,rs])
self.plot_gan(l_list,images[0],images[1],images[2],image[0],nmi,acc,rs) # plotting loss, average distance between all images
logfile.write('%f %f %f %f %f %f\n' %(loss_d,loss_style,loss_z,nmi,acc,rs))
else:
logfile.write(l_list)
logfile.write('%f %f %f\n' %(loss_d,loss_style,loss_z))
self.plot_gan(l_list,x_original,x_ctf,x_noise,self.validate_interval,training_file_name) # plotting loss, average distance between all images
self.P.save_weights(join(self.models,'%s_model_poisson' %training_file_name))
self.G.save_weights(join(self.models,'%s_model_generator' %training_file_name))
self.D.save_weights(join(self.models,'%s_model_discriminator' %training_file_name))
self.plot_gan(l_list,images[0],images[1],images[2],image[0]) # plotting loss, average distance between all images
ite +=1
if ite == int(self.steps):
break
logfile.close()
self.class_encode.save_weights(join(self.models,'%s_model' %training_file_name))
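Note (not part of the diff): in the validation branch above, accuracy_score compares cluster assignments to class labels directly, which only holds if the cluster indices already coincide with the label ids. A common alternative is to match clusters to classes with the Hungarian algorithm first (scipy's linear_sum_assignment appears in the imports above). A sketch, assuming y_true and y_pred are integer label arrays of the same length:

import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_accuracy(y_true, y_pred):
    n = max(y_true.max(), y_pred.max()) + 1
    contingency = np.zeros((n, n), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        contingency[t, p] += 1
    rows, cols = linear_sum_assignment(-contingency)  # maximise matched counts
    return contingency[rows, cols].sum() / y_true.size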
def train_aci_ae(self, maxiter=120e3, batch_size=256):
def train_aci_ae(self):
self.opt = self.optimizer(self.steps)
......@@ -240,7 +239,7 @@ class DynAE(super_class):
if self.check_dir(self.particle_stack_dir,'centroids'):
matrix_mean,matrix_mean_bias,matrix_variance,matrix_variance_bias = self.Cluster_Layer.get_weights()
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.sample_size)
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.batch_size)
for i in range(self.steps):
image,y = next(dist_it)
......@@ -256,7 +255,7 @@ class DynAE(super_class):
ilands = np.dot(out_matrix,out_matrix.T) > 0.999
components_matrix = connected_components(ilands)
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.sample_size)
strategy,dist_it = self.generator(self.paths,byte=self.bytes_pr_record,scale=self.large_rescale,batch_size=self.batch_size)
label_list = []
for i in range(int(self.depth/batch_size*self.num_gpus)):
......@@ -268,7 +267,7 @@ class DynAE(super_class):
label_list += np.take_along_axis(components_matrix,matrix,axis=1).tolist()
out = np.asarray(label_list)
np.save(join(self.refined,'sub_labels.npy'),out)
np.save(join(self.refined,'labels.npy'),out)
tmp_labs = np.unique(components_matrix)
......@@ -324,18 +323,23 @@ class DynAE(super_class):
plt.savefig(join(self.results,'protein projection angles'))
plt.clf()
def plot_gan(self,loss,original,ctf,noise,NMI,ACC,Recall):
zeros = np.zero(shape=[original.shape[1],original.shape[2]])
image = np.concatenate([np.concatenate([zeros,original],axis=0),np.concatenate([zeros,original],axis=0)],axis=1)
def plot_gan(self,loss,original,ctf,noise,real,NMI=None,ACC=None,Recall=None):
image = np.concatenate([np.concatenate([np.squeeze(ctf),np.squeeze(original)],axis=0),np.concatenate([np.squeeze(real),np.squeeze(noise)],axis=0)],axis=1)
loss = np.asarray(loss)
if self.verbose:
fig, ((ax1,ax2),(ax3,ax4)) = plt.subplot((2,2))
fig, ((ax1,ax2),(ax3,ax4)) = plt.subplots(2,2)
else:
fig, (ax1,ax2)= plt.subplot((1,2))
fig, (ax1,ax2)= plt.subplots(2)
ax1.scatterplot(np.arange(loss.shape[0]),loss,line='-')
ax1.scatter(np.arange(loss.shape[0]),loss)
ax1.set_title('loss')
ax2.imshow(image)
ax2.set_title('generator images')
if self.verbose:
ax3.scatterplot(np.arange(NMI.shape[0]),np.concatenate([NMI,ACC],axis=1))
ax4.scatterplot(np.arange(Recall.shape[0]),Recall)
\ No newline at end of file
ax3.scatter(np.arange(NMI.shape[0]),np.concatenate([NMI,ACC],axis=1))
ax4.scatter(np.arange(Recall.shape[0]),Recall)
plt.savefig(join(self.results,'logging.png'))
\ No newline at end of file
......@@ -16,54 +16,34 @@ def main():
parser.add_argument('--num_cpus',type=int,default = 8,help='The maximum number of CPUs to use for preprocessing data and K-means clustering')
parser.add_argument('--star', type=str, nargs='+',
help='List of paths to the star files; wildcards are accepted. The star files must refer to the .mrc files')
parser.add_argument('--ab', type=int,default=200,
parser.add_argument('--ab', type=int,default=50,
help='training batch size for the deep learning model')
parser.add_argument('--pb', type=int,default=200,
help='prediction batch size for the deep learning model')
parser.add_argument('--o', type=str,default='./results',
help='output directory')
parser.add_argument('--f16', type=str,default="False",
help='Apply Tensor Core acceleration to training and inference; requires a GPU with compute capability 7.0 or higher.')
parser.add_argument('--mp', type=int,default=50*10**3,
help='maximum number of particles to train on per epoch')
parser.add_argument('--epochs', type=int,default=20,help='The number of epochs to iterate through the dataset, defined to have the size by the parameter --mp')
help='maximum number of particles to train on per epoch')
parser.add_argument('--vi', type=int,default=20,help='validation interval where statistics are printed out.')
parser.add_argument('--tr', type=str,default = 'False', help='Use a pretrained model for fast predictions. The overall accuracy will decrease.')
parser.add_argument('--epochs', type=int,default=20,help='The number of epochs to iterate through the dataset, defined to have the size by the parameter --mp')
parser.add_argument('--verbose', type=str,default= 'False',help='Assess the performance of the model by including the original class labels')
parser.add_argument('--log', type=str,default="False",help='log all possible values to file (loss, pca_components, NMI, Recall, false positives, false negatives).')
parser.add_argument('--num_parts',type=int,default=2,help='the number subparts in the refinement process. This is to improve accuracy for highly similar proteins')
parser.add_argument('--resize',type=int,default=256,help='Image size to resize to for training and inference. If the image is smaller than 256, consider downsizing. The minimum size is 128')
parser.add_argument('--part_resize',type=int,default=128,help='When refining the components the image size is usually 128')
parser.add_argument('--early_stopping',type=float,default=0.001,help='Convergence threshold; training stops early if it is reached before the maximum number of iterations.')
parser.add_argument('--dbscan',type=int,default=50000,help='The number of samples to use for DBSCAN; memory use scales strongly with this value. Default is 50000')
parser.add_argument('--num_parts',type=int,default=30,help='Maximum number of Gaussian components to use.')
parser.add_argument('--lr',type=float,default=0.002,help='The learning rate of the model')
parser.add_argument('--vi',type=float,default=40,help='The validation interval')
parser.add_argument('--eps',type=float,default=0.01,help='The epsilon of the data.')
parser.add_argument('--nb',type=int,default=10,help='Nearest neighbours')
parser.add_argument('--pr',type=str,default='False',help='pretrain the signal model')
parser.add_argument('--mb',type=int,default=1,help='pretrain the signal model')
parser.add_argument('--ss',type=int,default=10,help='pretrain the signal model')
parser.add_argument('--rs',type=int,default=32,help='pretrain the signal model')
parser.add_argument('--val', type=str, nargs='+',help='Use a validation set which is different from the training set.')
parser.add_argument('--angels', type=str,default = 'True',help='do post-training to estimate the angular distribution')
parser.add_argument('--rab', type=int, default=5,help='The denoiser to train for image denoising.')
parser.add_argument('--lam',type=int,default=0.2,help='The lambda weight value for the pretraining loss of the autoencoder')
parser.add_argument('--ctf', type=str,default = 'True',help='Use CTF parameters for model.')
parser.add_argument('--noise', type=str,default = 'True',help='Use the noise generator for the model.')
args = parser.parse_args()
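A hypothetical invocation of the parser above (the entry-point name main_sortem.py is an assumption; the flags are the ones defined here, and the boolean-like options take the strings 'True'/'False'):

python main_sortem.py --star data/particles.star --o ./results --ab 50 --pb 200 --num_parts 30 --epochs 20 --mp 50000 --ctf True --noise True --verbose False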
......@@ -88,27 +68,13 @@ def main():
batch_size = args.ab
pred_batch_size = args.pb
num_cpus = args.num_cpus
star = args.star
train_light = args.tr
epochs = args.epochs
output = args.o
half_precision = args.f16
verbose = args.log
max_particles = args.mp
num_clusters = args.num_parts
binary = 0
num_particles = 0
width = 0
s1 = ' '.join(gpu_list)
l = [args.ab,args.pb,args.num_parts,args.num_cpus,args.num_gpus,args.vi,' '.join(gpu_list),args.f16,args.tr,args.epochs,args.log,args.mp,args.eps,args.nb,' '.join(args.star),args.dbscan,args.lam,args.lr,args.val,args.mb,args.ss,args.rs,args.pr,'current_image',binary,num_particles,width]
l = [args.ab,args.pb,args.num_parts,args.num_cpus,args.num_gpus,args.vi,' '.join(gpu_list),args.f16,args.verbose,args.epochs,args.mp,' '.join(args.star),args.lr,args.angels,args.ctf,args.noise,'current_image',binary,num_particles,width]
if not isdir(args.o):
mkdir(args.o)
......
from super_clas_sortem import MaxBlurPooling2D,super_class
from super_clas_sortem import super_class
from tensorflow.keras import Model
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten, Lambda, InputSpec, Input,LeakyReLU,Conv2DTranspose, Dense,Conv2D,GlobalAveragePooling2D,BatchNormalization,Activation,UpSampling2D,Conv3D,Conv3DTranspose,LeakyReLU,Dense,UpSampling3D,MaxPool3D,MaxPool2D,ReLU
from tensorflow.keras.layers import Flatten,LeakyReLU,Conv2DTranspose, Dense,Conv2D,UpSampling2D,Conv3D,Conv3DTranspose,LeakyReLU,Dense,UpSampling3D,ReLU
from utils_sortem import Spectral_norm,Instance_norm,transform_3D
class fit_layer(Model):
......@@ -26,20 +26,37 @@ class Cluster_Layer(Model):
def __init__(self):
super(Cluster_Layer,self).__init__()
self.dense_mean = Dense(4**3)
self.dense_var = Dense(4**3)
self.dense_var = Dense(1)
self.dense_mean_angel = Dense(2)
self.dense_var_angel = Dense(1)
def call(self,catagorial_variable):
catagorial_variable = Flatten()(catagorial_variable)
mean = self.dense_mean(catagorial_variable)
s =tf.shape(mean)
batch = s[0]
length = s[1]
epsilon = tf.random.normal(shape=[batch,length])
logvar = self.dense_var(catagorial_variable)
out = epsilon*logvar+mean
return out
mean_angel = self.dense_mean_angel(catagorial_variable)
logvar_angel = self.dense_var_angel(catagorial_variable)
epsilon_angel = tf.random.normal(shape=[batch,2])
out_angel = tf.nn.softsign(epsilon_angel*logvar_angel+mean_angel)
return out,out_angel
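Note (not part of the diff): Cluster_Layer samples as epsilon*logvar + mean; in the usual VAE reparameterisation the noise is scaled by the standard deviation exp(0.5*logvar), not by the log-variance itself. A sketch of that standard form, under the assumption that dense_var and dense_var_angel are meant to predict a log-variance:

import tensorflow as tf

def reparameterise(mean, logvar):
    # z = mean + sigma * eps, with sigma = exp(0.5 * logvar)
    eps = tf.random.normal(tf.shape(mean))
    return mean + eps * tf.exp(0.5 * logvar)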
def AdaIn_3D(inputs,s1,b1):
b= tf.shape(inputs)[0]
w= tf.shape(inputs)[1]
......@@ -86,27 +103,27 @@ class Double_Dense(Model):
class Generator_AdaIN_Noise(Model):
def __init__(self,gf_dim=32):
def __init__(self,gf_dim=64):
self.gf_dim = gf_dim
super(Generator_AdaIN_Noise,self).__init__()
self.zmap_0 = Double_Dense( self.gf_dim * 32)
self.zmap_0 = Double_Dense( self.gf_dim * 8)
self.h0 = AdaIn_2D
self.h0_a = LeakyReLU()
self.h1= Conv2DTranspose( self.gf_dim * 32,3,strides=2,padding='SAME')
self.z_map_1= Double_Dense(self.gf_dim * 32)
self.h1= Conv2DTranspose( self.gf_dim * 8,3,strides=2,padding='SAME')
self.z_map_1= Double_Dense(self.gf_dim * 8)
self.h1_aI = AdaIn_2D
self.h1 = LeakyReLU()
self.h1_a = LeakyReLU()
self.h2 = Conv2DTranspose(self.gf_dim * 16,3,strides=2,padding='SAME')
self.h2_a = LeakyReLU()
self.z_map_2= Double_Dense(self.gf_dim * 16)
self.h2_aI = AdaIn_2D
self.h2 = LeakyReLU()
self.h2_a = LeakyReLU()
self.h5 = Conv2DTranspose(self.gf_dim*4,4,strides=2,padding='SAME')
self.z_map_4 = Double_Dense(self.gf_dim*4)
......@@ -124,9 +141,10 @@ class Generator_AdaIN_Noise(Model):