Commit 9f7bfaf7 authored by Jonathan Juhl's avatar Jonathan Juhl
Browse files

new_stuff

parent 9f6fc2c7
......@@ -2,244 +2,246 @@
import numpy as np
import os
from super_clas_sortem import super_class
from os.path import join,getsize,isdir,isfile,dirname,basename
from os import listdir,rename
import csv
#from fac_sortem import DynAE
from os import listdir,mkdir
import glob
from os.path import isabs
from fac_sortem import DynAE
from fac_sortem import GAN_NERF
import tensorflow as tf
# NOTE(review): this span is a corrupted unified-diff rendering -- the old
# (parameter_file_path-based) and new (args-dict-based) revisions of the
# constructor are interleaved and the original indentation has been lost.
# It is not runnable as-is; the comments below only annotate what is visible.
class control_flow:
# old revision: constructor took a path to a parameter file
def __init__(self,parameter_file_path):
# new revision: constructor takes a dict of parsed CLI arguments
def __init__(self,args):
self.args = args
dir_path = os.path.dirname(os.path.realpath(__file__))
# expand a single star-file pattern into a list of files
# NOTE(review): `star` is unbound in the visible code -- presumably
# self.args['star']; verify against the full file.
if not isinstance(self.args['star'], list) and isdir(star):
star_files = glob.glob(self.args['star'])
if star_files == []:
print("no star files in directory. You must point to atleast one star file to run Sortinator.")
exit()
star_files = [join(dir_path,join(star,i)) for i in star_files]
else:
# old-revision fragment: delegated setup to super_class
super_class.__init__(self
,parameter_file_path
)
# old-revision fragment of get_star_file_parameters begins here
def get_star_file_parameters(star_files):
counter = []
f = 0
names = []
star_files = np.unique(star_files)
labels_list = []
count = 0
for z in star_files:
#print(z)
c = z[::-1].split('/',1)[1][::-1]
#c = '/emcc/au587640/cryosparc2_data/P139/'
with open(z, newline='') as csvfile:
# new-revision __init__ resumes here: make star paths absolute
self.args['star'] = [join(dir_path,i) for i in self.args['star']]
# reuse star-file metadata cached under tmp when present
if not isfile(join(self.args['tmp'],'depth.npy')):
depth,mrc_paths,voltage = self.get_star_file_parameters()
else:
depth = np.load(join(self.args['tmp'],'depth.npy'))
mrc_paths = np.load(join(self.args['tmp'],'names.npy'))
length,bytes_pr_record = self.get_parameters(mrc_paths[0])
self.args['mrc_paths'] = mrc_paths
self.args['size'] = length
self.args['number_particles'] = depth
self.args['bpr'] = bytes_pr_record
# load cached CTF constants as tf tensors when --ctf was given
if args['ctf']:
self.args['kvolts'] = tf.constant(np.load(join(self.args['tmp'],'electron_volts.npy')),tf.float32)
self.args['sphe_ab'] = tf.constant(np.load(join(self.args['tmp'],'spherical_abberation.npy')),tf.float32)
self.args['amp_contrast'] = tf.constant(np.load(join(self.args['tmp'],'amplitude_contrast.npy')),tf.float32)
# run the model, then split input star files by the predicted labels
GAN_NERF(self.args)
final_labels = np.load(join(self.args['results'],'labels.npy'))
self.write_star_file(star_files,final_labels)
# print(self.args );exit()
# NOTE(review): corrupted diff rendering -- the new revision of
# get_star_file_parameters is interleaved with old-revision fragments
# (self.ctf / self.verbose / self.particle_stack_dir vs. the new
# self.args[...] dict) and indentation is lost. Not runnable as-is.
# Purpose (new revision): parse RELION .star files, collect particle
# names, optional class labels, angular parameters and CTF parameters,
# cache them as .npy files under self.args['tmp'], and return
# (particle_count, unique_names, voltage).
def get_star_file_parameters(self):
counter = []
f = 0
names = []
star_files = np.unique(self.args['star'])
labels_list = []
angular_list = []
count = 0
for z in star_files:
# recover the directory part of the star-file path
c = z[::-1].split('/',1)[1][::-1]
with open(z, newline='') as csvfile:
reader = list(csvfile)
# header lines start with _rln or are a bare loop_ marker
header = list(filter(lambda x: '_rln' == x[0:4] or 'loop_' == x.strip(),reader))
header = [i.split()[0] for i in header]
heads = []
# NOTE(review): duplicated lines below are diff residue
reader = list(csvfile)
header = list(filter(lambda x: '_rln' == x[0:4] or 'loop_' == x.strip(),reader))
header = [i.split()[0] for i in header]
heads = []
head = []
counter = 0
# group header tags into per-loop_ blocks
for i in header:
if 'loop_' == i:
counter+=1
head = []
counter = 0
for i in header:
if 'loop_' == i:
counter+=1
head = []
heads.append(head)
else:
head.append(i)
take_this = []
# old-revision fragment (attribute-based flags)
for i in heads:
if self.ctf:
if '_rlnVoltage' in i:
take_that = i
if '_rlnImageName' in i:
take_this = i
header = take_this
name = header.index('_rlnImageName')
if self.ctf:
voltage = take_that.index('_rlnVoltage')
defocusU = header.index('_rlnDefocusU')
defocusV = header.index('_rlnDefocusV')
defocusAngle = header.index('_rlnDefocusAngle')
abberation = take_that.index('_rlnSphericalAberration')
amp_contrast = take_that.index('_rlnAmplitudeContrast')
phase_shift = header.index('_rlnPhaseShift')
heads.append(head)
else:
head.append(i)
take_this = []
# new-revision fragment: pick the loop_ blocks holding voltage/image name
for i in heads:
if self.args['ctf']:
if '_rlnVoltage' in i:
take_that = i
if '_rlnImageName' in i:
take_this = i
header = take_this
# column indices of the tags used below
name = header.index('_rlnImageName')
if self.args['ctf']:
voltage = take_that.index('_rlnVoltage')
defocusU = header.index('_rlnDefocusU')
defocusV = header.index('_rlnDefocusV')
defocusAngle = header.index('_rlnDefocusAngle')
abberation = take_that.index('_rlnSphericalAberration')
amp_contrast = take_that.index('_rlnAmplitudeContrast')
phase_shift = header.index('_rlnPhaseShift')
ctf_params = []
# --verbose requires ground-truth class labels in the star file
if self.args['verbose']:
try:
class_num = header.index('_rlnClassNumber')
except:
self.args['verbose'] = False
print("the --verbose true cannot be run as _rlnClassNumber is missing ")
# --ang_error requires an existing angular reconstruction
if self.args['ang_error']:
try:
class_num = header.index('_rlnClassNumber')
angel_rot = header.index('_rlnAngleRot')
angel_tilt = header.index('_rlnAngleTilt')
angle_psi = header.index('_rlnAnglePsi')
origin_x = header.index('_rlnOriginX')
origin_y = header.index('_rlnOriginY')
# old-revision fragment (self.verbose / --log wording)
ctf_params = []
if self.verbose:
try:
class_num = header.index('_rlnClassNumber')
except:
self.verbose = False
print("the --log true cannot be run as _rlnClassNumber i missing ")
for row in reader:
if self.ctf:
if len(take_that) == len(row.split()) and counter ==2:
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
V = float(row.split()[voltage])
# relativistic electron wavelength from acceleration voltage V
electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1))
counter = 0
if len(header)== len(row.split()):
f+=1
current_name = row.split()[name].split('@')[1]
if len(names) != 0:
if names[-1] != current_name:
names.append(join(c,current_name))
else:
names.append(join(c,current_name))
if self.verbose:
labels_list.append(int(row.split()[class_num]))
if counter == 1:
if self.verbose:
V = float(row.split()[voltage])
except:
self.args['ang_error'] = False
print("the --verbose true cannot be run as angular reconstruction is missing ")
# new-revision row scan: count particles, gather names/labels/angles/CTF
for row in reader:
if self.args['ctf']:
if len(take_that) == len(row.split()) and counter ==2:
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
counter = 0
if self.ctf:
ctf_params.append([float(row.split()[phase_shift]),float(row.split()[defocusU]),float(row.split()[defocusV]),float(row.split()[defocusAngle])])
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
V = float(row.split()[voltage])
electron_volts = (1.23*10**3)/np.sqrt(V*(V*10**(-7)*1.96+1))
counter = 0
if len(header)== len(row.split()):
f+=1
# image token is 'index@stack_file'; keep the file part
current_name = row.split()[name].split('@')[1]
if len(names) != 0:
if names[-1] != current_name:
names.append(join(c,current_name))
else:
names.append(join(c,current_name))
if self.args['verbose']:
labels_list.append(int(row.split()[class_num]))
if self.args['ang_error']:
angular_list.append([float(row.split()[angel_rot]),float(row.split()[angel_tilt]),float(row.split()[angle_psi]),float(row.split()[origin_x]),float(row.split()[origin_y])])
current_id = row.split()[name].split('@')[0]
# old-revision fragment: caches under self.particle_stack_dir
np.save(join(self.particle_stack_dir,'depth.npy'),f)
np.save(join(self.particle_stack_dir,'names.npy'),names)
if self.ctf:
np.save(join(self.particle_stack_dir,'electron_volts.npy'),V)
np.save(join(self.particle_stack_dir,'spherical_abberation.npy'),abberation_d)
np.save(join(self.particle_stack_dir,'amplitude_contrast.npy'),amp_contrast_d)
np.save(join(self.particle_stack_dir,'ctf_params.npy'),np.asarray(ctf_params))
if self.verbose:
np.save(join(self.particle_stack_dir,'labels.npy'),np.asarray(labels_list))
return f,np.unique(names)
# stray old-revision __init__ lines (diff residue)
dir_path = os.path.dirname(os.path.realpath(__file__))
if not isinstance(self.star, list) and isdir(star):
star_files = glob.glob(self.star)
if star_files == []:
print("no star files in directory. You must point to atleast one star file to run Sortinator.")
exit()
if counter == 1:
if self.args['ctf']:
V = float(row.split()[voltage])
abberation_d = float(row.split()[abberation])
amp_contrast_d = float(row.split()[amp_contrast])
counter = 0
if self.args['ctf']:
ctf_params.append([float(row.split()[phase_shift]),float(row.split()[defocusU]),float(row.split()[defocusV]),float(row.split()[defocusAngle])])
current_id = row.split()[name].split('@')[0]
# new-revision epilogue: cache everything under self.args['tmp']
np.save(join(self.args['tmp'],'depth.npy'),f)
np.save(join(self.args['tmp'],'names.npy'),names)
if self.args['ctf']:
np.save(join(self.args['tmp'],'electron_volts.npy'),V)
np.save(join(self.args['tmp'],'spherical_abberation.npy'),abberation_d)
np.save(join(self.args['tmp'],'amplitude_contrast.npy'),amp_contrast_d)
np.save(join(self.args['tmp'],'ctf_params.npy'),np.asarray(ctf_params))
if self.args['verbose']:
np.save(join(self.args['tmp'],'labels.npy'),np.asarray(labels_list))
if self.args['ang_error']:
np.save(join(self.args['tmp'],'angular_error.npy'),np.asarray(angular_list))
return f,np.unique(names),V
# NOTE(review): corrupted diff rendering -- the start of the old-revision
# write_star_file is interleaved with the tail of the old __init__
# (parameter_file_path / DynAE pipeline). Not runnable as-is.
def write_star_file(self,star_files,labels):
names = []
for z in star_files:
with open(z, newline='') as csvfile:
# old __init__ fragment: build absolute star paths
star_files = [join(dir_path,join(star,i)) for i in star_files]
else:
star_files = [join(dir_path,i) for i in self.star]
reader = list(csvfile)
header = list(filter(lambda x: '_rln' in x,reader))
# old __init__ fragment: reuse cached metadata when present
if not isfile(join(self.particle_stack_dir,'depth.npy')):
depth,mrc_paths = get_star_file_parameters(star_files)
else:
depth = np.load(join(self.particle_stack_dir,'depth.npy'))
mrc_paths = np.load(join(self.particle_stack_dir,'names.npy'))
length,bytes_pr_record = self.get_parameters(mrc_paths[0])
# collect data rows: one token per header column
for row in reader:
if len(row.split()) == len(header):
names.append(row.split())
for index,i in enumerate(np.unique(labels)):
f = open(join( self.args['star'],'cluster_%s.star' %index), 'w')
with f:
# old __init__ fragment: run the DynAE pipeline and split by labels
self.add_params(parameter_file_path,'current_image.png',bytes_pr_record,depth,length)
DynAE(parameter_file_path,mrc_paths)
final_labels = np.load(join(self.refined,'labels.npy'))
self.write_star_file(star_files,final_labels)
def add_params(self,parameter_file_path,current_image,binary,num_particles,width):
    """Rewrite the last four fields of ``parameters.csv`` in place.

    Reads the single tab-separated record in
    ``<parameter_file_path>/parameters.csv``, overwrites its trailing four
    fields with the supplied values, and writes the record back.

    Parameters
    ----------
    parameter_file_path : str
        Directory containing ``parameters.csv``.
    current_image : str
        Stored in field ``[-4]`` (e.g. ``'current_image.png'``).
    binary :
        Stored in field ``[-3]``; bytes per particle record.
    num_particles :
        Stored in field ``[-2]``.
    width :
        Stored in field ``[-1]``; image side length.
    """
    csv_path = join(parameter_file_path, 'parameters.csv')
    with open(csv_path, 'r', newline='') as file:
        # Fix: the original bound this to a variable named `writer`
        # even though it is a csv.reader -- renamed for clarity.
        reader = csv.reader(file, delimiter='\t')
        # the file holds exactly one record
        parameters = list(reader)[0]
    parameters[-1] = width
    parameters[-2] = num_particles
    parameters[-3] = binary
    parameters[-4] = current_image
    with open(csv_path, 'w', newline='') as file:
        writer = csv.writer(file, delimiter='\t')
        writer.writerow(parameters)
def write_star_file(self,star_files,labels):
    """Split the input STAR files into one STAR file per cluster label.

    Fix: this span was a diff rendering with every line of the body
    duplicated; the duplicates are collapsed back into one definition.

    Parameters
    ----------
    star_files : iterable of str
        Paths of the input .star files to re-partition.
    labels : numpy array
        One cluster label per collected data row.

    Side effects: writes ``cluster_<k>.star`` (one per unique label) into
    ``self.star_files``, each holding the shared header plus the rows whose
    label equals that cluster's label.
    """
    names = []
    for z in star_files:
        with open(z, newline='') as csvfile:
            reader = list(csvfile)
            # header lines are those mentioning a _rln tag
            header = list(filter(lambda x: '_rln' in x, reader))
            for row in reader:
                # data rows have exactly one token per header column
                if len(row.split()) == len(header):
                    names.append(row.split())
    for index, i in enumerate(np.unique(labels)):
        # NOTE(review): assumes self.star_files is a directory path here
        # -- confirm against the constructor, which is not visible intact.
        f = open(join(self.star_files, 'cluster_%s.star' % index), 'w')
        with f:
            f.write('\n')
            f.write('data_images\n')
            f.write('\n')
            f.write('loop_\n ')
            for z in header:
                f.write(z)
            f.write('')
            for lab, row in zip(labels.tolist(), names):
                if lab == i:
                    f.write(' '.join(row) + '\n')
            # redundant with the `with` block, kept from the original
            f.close()
def read(self,filename,header = 1024):
    """Parse the dimensions of an MRC-style particle stack.

    Fix: this span was a diff rendering with the whole method duplicated
    verbatim; collapsed to one definition. Also replaces the deprecated
    ``np.fromstring`` (removed for bytes input in modern NumPy) with
    ``np.frombuffer``.

    Parameters
    ----------
    filename : str
        Path to the binary stack file.
    header : int, optional
        Header size in bytes to skip (default 1024, the MRC header size).

    Returns
    -------
    (NX, NY, NZ, recordsize) where NX/NY/NZ are the first three int32
    words of the header and recordsize is the payload bytes per section.
    """
    with open(filename, 'rb') as f:
        binary_header = f.read(header)
    file_size = getsize(filename)
    NX = np.frombuffer(binary_header[0:4], np.int32)
    NY = np.frombuffer(binary_header[4:8], np.int32)
    NZ = np.frombuffer(binary_header[8:12], np.int32)
    # bytes per section = payload size divided by section count
    recordsize = int((file_size - header) / (NZ[0]))
    return NX[0], NY[0], NZ[0], recordsize
def get_parameters(self,paths):
    """Return ``(length, record_size)`` for a particle stack file.

    Fix: this span was a diff rendering with the method duplicated
    verbatim; collapsed to one definition.

    Delegates to ``self.read`` and discards the width and depth fields.
    """
    width, length, depth, record_size_new = self.read(paths)
    return length, record_size_new
This diff is collapsed.
......@@ -11,95 +11,230 @@ def main():
parser = argparse.ArgumentParser(description='Run sortinator.')
parser.add_argument('--num_gpus',type=int,default = 1,
help='Number of GPUs to use.')
parser.add_argument('--gpu_list',type=str, nargs='+',default = None,
parser.add_argument('--gpu_list',type=int, nargs='+',default = None,
help='List of GPU devises, if None it will run on gpus sequentially from GPU:0 and up.')
parser.add_argument('--num_cpus',type=int,default = 8,help='The maximum allowed cpus to use for preprocessing data and Kmeans clustering')
parser.add_argument('--star', type=str, nargs='+',
help='list of path to the star files, wild cards are accepted. The star file must refer to the .mrc files')
parser.add_argument('--ab', type=int,default=100,
parser.add_argument('--batch_size', type=int,default=100,
help='deep learning model training batch')
parser.add_argument('--pb', type=int,default=200,
parser.add_argument('--p_batch_size', type=int,default=200,
help='deep learning model training batch')
parser.add_argument('--o', type=str,default='./results',
help='output directory')
parser.add_argument('--f16', type=str,default="False",
parser.add_argument('--f16', dest='f16',action='store_true',
help='Apply Tensor core acceleration to training and inference, requires compute capability of 10.0 or higher.')
parser.add_argument('--mp', type=int,default=50*10**3,
help='max amount of particle to train pr. epoch')
parser.add_argument('--vi', type=int,default=20,help='validation interval where statistics are printed out.')
parser.add_argument('--mp', type=int,default=100,
help='max amount of steps to train pr. size')
parser.add_argument('--vi', type=int,default=100,help='validation interval where statistics are printed out.')
parser.add_argument('--epochs', type=int,default=20,help='The number of epochs to iterate through the dataset, defined to have the size by the parameter --mp')
parser.add_argument('--verbose', type=str,default= 'False',help='se the performance of the model by including original class labels')
parser.add_argument('--movie_int', type=int,default=500,help='validation interval where models at full size are printed out.')
parser.add_argument('--log', type=str,default="False",help='log all possible values to file (loss, pca_components,NMI,Recall,false positives,false negatives.')
parser.add_argument('--verbose',dest='verbose', action='store_true',help='se the performance of the model by including original class labels')
parser.add_argument('--num_parts',type=int,default=4,help='Number of gaussian components to use. (This is the maximum number)')
parser.add_argument('--lr',type=float,default=0.002,help='The learning rate of the model')
parser.add_argument('--angels', type=str,default = 'True',help='do post training where you estimate the angular distribution')
parser.add_argument('--lr_b_g',type=float,default=10**(-4),help='The start learning rate of the generator')
parser.add_argument('--lr_b_d',type=float,default=10**(-4),help='The start learning rate of the descriminator')
parser.add_argument('--lr_e_g',type=float,default=10**(-4),help='The end learning rate of the generator')
parser.add_argument('--lr_e_d',type=float,default=10**(-4),help='The end learning rate of the descriminator')
parser.add_argument('--ctf', type=str,default = 'True',help='Use CTF parameters for model.')
parser.add_argument('--ctf', dest='ctf',action='store_true',default=False,help='Use CTF parameters for model.')
parser.add_argument('--noise', dest='noise',action='store_true',help='Use the noise generator for model. Set true or false boolean.')
parser.add_argument('--s_1_steps',type=int,default=5000,help='how many steps to generate the 32 x 32 model')
parser.add_argument('--s_2_steps',type=int,default=4000,help='how many steps to generate the 64 x 64 model')
parser.add_argument('--s_3_steps',type=int,default=3000,help='how many steps to generate the 128 x 128 model')
parser.add_argument('--noise', type=str,default = 'True',help='Use the noise generator for model .')
parser.add_argument('--s_4_steps',type=int,default=2000,help='how many steps to generate the 256 x 256 model')
parser.add_argument('--median_noise', type=int,default =10000,help='random vectors for fitting.')
parser.add_argument('--top_off',type=int,default=1000,help='steps to finish training at the speficied resolution')
parser.add_argument('--interpolation_count', type=int,default =10,help='interpolating images.')
parser.add_argument('--l_reg',type=float,default=0.01,help='the lambda regulization of the diversity score loss')
parser.add_argument('--angular_samples', type=int,default =10000,help='samples to make 2d histogram.')
parser.add_argument('--feature_size',type=float,default=128,help='the input feature size')
parser.add_argument('--over_cluster', dest='over_cluster',action='store_true',default=False,help='Use CTF parameters for model.')
parser.add_argument('--ang_error', dest='ang_error',action='store_true',default=False,help='Use CTF parameters for model.')
parser.add_argument('--feature_samples', type=int,default =2000,help='samples to make 2d for our z-features.')
parser.add_argument('--angular_clusters', type=int,default =8,help='number of gaussians to draw from n^2 .')
args = parser.parse_args()
if args.gpu_list == None:
gpu_list = ['/GPU:0']
else:
if len(args.gpu_list) != args.num_gpus:
print("the list of gpus to use is not the same length as the number of gpus specified. The number of gpus will be changed to match list length")
num_gpus = len(args.gpu_list)
gpu_list = []
for s in range(num_gpus):
gpu_list.append('/GPU:%i' %s)
else:
gpu_list = []
num_gpus = len(args.gpu_list)
for s in range(num_gpus):
gpu_list.append('/GPU:%i' %s)
if isinstance(args.num_gpus,int) and args.num_gpus > 0:
print("num gpus is instance of: int", isinstance(args.num_gpus,int), "and is: %i" %args.num_gpus)
else:
assert print("the number of gpus is not an integer or is less than 0")
if isinstance(args.num_cpus,int) and args.num_cpus > 0:
print("num gpus is instance of: int", isinstance(args.num_cpus,int), "and is: %i" %args.num_cpus)
else:
assert print("the number of cpus is not an integer or is less than 0")
if isinstance(args.star,list) and all(isinstance(x, str) for x in args.star):
print("star file is a list of strings",args.star)
else:
assert print("star file is not a string ")
if (isinstance(args.gpu_list,list) and all(isinstance(x, int) for x in args.gpu_list)) or args.gpu_list == None:
print("gpu list file is a list of integers",args.gpu_list)
else:
assert print("gpu list file is not a list of integers")
binary = 0
num_particles = 0
width = 0
s1 = ' '.join(gpu_list)
l = [args.ab,args.pb,args.num_parts,args.num_cpus,args.num_gpus,args.vi,' '.join(gpu_list),args.f16,args.verbose,args.epochs,args.mp,' '.join(args.star),args.lr,args.angels,args.ctf,args.noise,args.median_noise,args.interpolation_count,args.angular_samples,args.angular_clusters,args