Commit e9f52417 authored by Malthe Kjær Bisbo's avatar Malthe Kjær Bisbo
Browse files

Cleaned up cand_opp and missing missing wrap() calls

parent 531b2f5e
...@@ -94,7 +94,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch): ...@@ -94,7 +94,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
d = np.array([get_mic_distance(p1,p2,self.cell,self.pbc) for p1,p2 in zip(self.pos_init,self.atoms.get_positions())]) d = np.array([get_mic_distance(p1,p2,self.cell,self.pbc) for p1,p2 in zip(self.pos_init,self.atoms.get_positions())])
max_covDisplace = np.max(d/self.d_cov) max_covDisplace = np.max(d/self.d_cov)
if max_covDisplace > self.dmax_cov: if max_covDisplace > self.dmax_cov:
print('RELAX STOPPED on rank {} - max relax dist (>{} covDist) reached.'.format(self.rk, self.dmax_cov), flush=True)
valid_displace = False valid_displace = False
return valid_displace return valid_displace
...@@ -109,7 +108,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch): ...@@ -109,7 +108,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
indices = np.delete(np.arange(self.atoms.get_number_of_atoms()), indices_fixed) indices = np.delete(np.arange(self.atoms.get_number_of_atoms()), indices_fixed)
pos_z = self.atoms.positions[indices,2] pos_z = self.atoms.positions[indices,2]
if np.any(pos_z < self.zlim[0]) or np.any(pos_z > self.zlim[1]): if np.any(pos_z < self.zlim[0]) or np.any(pos_z > self.zlim[1]):
print('RELAXATION STOPPED on rank {} due to - zlim crossed.'.format(self.rk))
valid_z = False valid_z = False
return valid_z return valid_z
...@@ -118,6 +116,5 @@ class BFGSLineSearch_constrained(BFGSLineSearch): ...@@ -118,6 +116,5 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
if self.blmin is not None or self.blmax is not None: if self.blmin is not None or self.blmax is not None:
valid_bondlengths = check_valid_bondlengths(self.atoms, self.blmin, self.blmax) valid_bondlengths = check_valid_bondlengths(self.atoms, self.blmin, self.blmax)
if not valid_bondlengths: if not valid_bondlengths:
#print('RELAXATION STOPPED on rank {} due to - invalid bondlengths.'.format(self.rk))
valid_bondlengths = False valid_bondlengths = False
return valid_bondlengths return valid_bondlengths
...@@ -9,7 +9,8 @@ from candidate_operations.candidate_generation import CandidateGenerator ...@@ -9,7 +9,8 @@ from candidate_operations.candidate_generation import CandidateGenerator
def pos_add_sphere(rattle_strength): def pos_add_sphere(rattle_strength):
# Rattle within a sphere """Help function for rattling within a sphere
"""
r = rattle_strength * np.random.rand()**(1/3) r = rattle_strength * np.random.rand()**(1/3)
theta = np.random.uniform(low=0, high=2*np.pi) theta = np.random.uniform(low=0, high=2*np.pi)
phi = np.random.uniform(low=0, high=np.pi) phi = np.random.uniform(low=0, high=np.pi)
...@@ -19,7 +20,9 @@ def pos_add_sphere(rattle_strength): ...@@ -19,7 +20,9 @@ def pos_add_sphere(rattle_strength):
return pos_add return pos_add
def pos_add_sphere_shell(rmin, rmax): def pos_add_sphere_shell(rmin, rmax):
# Rattle within a sphere """Help function for rattling atoms within a spherical
shell.
"""
r = np.random.uniform(rmin**3, rmax**3)**(1/3) r = np.random.uniform(rmin**3, rmax**3)**(1/3)
theta = np.random.uniform(low=0, high=2*np.pi) theta = np.random.uniform(low=0, high=2*np.pi)
phi = np.random.uniform(low=0, high=np.pi) phi = np.random.uniform(low=0, high=np.pi)
...@@ -52,25 +55,22 @@ class RattleMutation(CandidateGenerator): ...@@ -52,25 +55,22 @@ class RattleMutation(CandidateGenerator):
description: Name of the operation, which will be saved in description: Name of the operation, which will be saved in
info-dict of structures, on which the operation is applied. info-dict of structures, on which the operation is applied.
""" """
def __init__(self, n_top, Nrattle=3, rattle_range=3, blmin=0.7, blmax=1.4, def __init__(self, n_top, Nrattle=3, rattle_range=3,
force_all_bonds_valid=False, description='RattleMutation'): description='RattleMutation', *args, **kwargs):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax, CandidateGenerator.__init__(self, *args, **kwargs)
force_all_bonds_valid=force_all_bonds_valid)
self.description = description self.description = description
self.n_top = n_top self.n_top = n_top
self.probability = Nrattle/n_top self.probability = Nrattle/n_top
self.rattle_range = rattle_range self.rattle_range = rattle_range
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents): def operation(self, parents):
a = parents[0] a = parents[0]
a = self.rattle(a) a = self.rattle(a)
a = self.finalize(a)
return a return a
def rattle(self, atoms): def rattle(self, atoms):
"""Standardized candidate generation method for all mutation """ Rattles atoms one at a time within a sphere of radius
and crossover operations. self.rattle_range.
""" """
a = atoms.copy() a = atoms.copy()
Natoms = len(a) Natoms = len(a)
...@@ -91,18 +91,22 @@ class RattleMutation(CandidateGenerator): ...@@ -91,18 +91,22 @@ class RattleMutation(CandidateGenerator):
pos_add = pos_add_sphere(self.rattle_range) pos_add = pos_add_sphere(self.rattle_range)
a.positions[i] += pos_add a.positions[i] += pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
# Check if rattle was valid # Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i]) valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
if not valid_bondlengths:
valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
a.positions[i] = posi_0 a.positions[i] = posi_0
else: else:
break break
if valid_bondlengths: if valid_operation:
return a return a
else: else:
# If mutation is not successfull in supplied number # If mutation is not successfull in supplied number
# of trials, return initial structure. # of trials, return initial structure.
return atoms return None
class RattleMutation2(CandidateGenerator): class RattleMutation2(CandidateGenerator):
"""Class to perform rattle mutations on structures. """Class to perform rattle mutations on structures.
...@@ -124,25 +128,22 @@ class RattleMutation2(CandidateGenerator): ...@@ -124,25 +128,22 @@ class RattleMutation2(CandidateGenerator):
description: Name of the operation, which will be saved in description: Name of the operation, which will be saved in
info-dict of structures, on which the operation is applied. info-dict of structures, on which the operation is applied.
""" """
def __init__(self, n_top, Nrattle=3, blmin=0.7, blmax=1.4, def __init__(self, n_top, Nrattle=3, description='RattleMutation',
force_all_bonds_valid=False, description='RattleMutation'): *args, **kwargs):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax, CandidateGenerator.__init__(self, *args, **kwargs)
force_all_bonds_valid=force_all_bonds_valid)
self.description = description self.description = description
self.n_top = n_top self.n_top = n_top
self.probability = Nrattle/n_top self.probability = Nrattle/n_top
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents): def operation(self, parents):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
a = parents[0] a = parents[0]
a = self.rattle(a) a = self.rattle(a)
a = self.finalize(a)
return a return a
def rattle(self, atoms): def rattle(self, atoms):
""" Repeatedly rattles a random atom to the visinity of another
random atom.
"""
a = atoms.copy() a = atoms.copy()
Natoms = len(a) Natoms = len(a)
Nslab = Natoms - self.n_top Nslab = Natoms - self.n_top
...@@ -168,19 +169,22 @@ class RattleMutation2(CandidateGenerator): ...@@ -168,19 +169,22 @@ class RattleMutation2(CandidateGenerator):
pos_add = pos_add_sphere_shell(rmin, rmax) pos_add = pos_add_sphere_shell(rmin, rmax)
a.positions[i] = np.copy(a.positions[j]) + pos_add a.positions[i] = np.copy(a.positions[j]) + pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
# Check if rattle was valid # Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i]) valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
if not valid_bondlengths: valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
a.positions[i] = posi_0 a.positions[i] = posi_0
else: else:
break break
if valid_bondlengths: if valid_operation:
return a return a
else: else:
# If mutation is not successfull in supplied number # If mutation is not successfull in supplied number
# of trials, return initial structure. # of trials, return initial structure.
return atoms return None
class PermutationMutation(CandidateGenerator): class PermutationMutation(CandidateGenerator):
...@@ -204,25 +208,21 @@ class PermutationMutation(CandidateGenerator): ...@@ -204,25 +208,21 @@ class PermutationMutation(CandidateGenerator):
info-dict of structures, on which the operation is applied. info-dict of structures, on which the operation is applied.
""" """
def __init__(self, n_top, Npermute=3, blmin=0.7, blmax=1.4, def __init__(self, n_top, Npermute=3,
force_all_bonds_valid=False, description='PermutationMutation'): description='PermutationMutation', *args, **kwargs):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax, CandidateGenerator.__init__(self, *args, **kwargs)
force_all_bonds_valid=force_all_bonds_valid)
self.description = description self.description = description
self.n_top = n_top self.n_top = n_top
self.probability = Npermute/n_top self.probability = Npermute/n_top
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents): def operation(self, parents):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
a = parents[0] a = parents[0]
a = self.mutate(a) a = self.permute(a)
a = self.finalize(a)
return a return a
def mutate(self, atoms): def permute(self, atoms):
""" Permutes atoms of different type in structure.
"""
a = atoms.copy() a = atoms.copy()
Natoms = len(a) Natoms = len(a)
Nslab = Natoms - self.n_top Nslab = Natoms - self.n_top
...@@ -265,5 +265,5 @@ class PermutationMutation(CandidateGenerator): ...@@ -265,5 +265,5 @@ class PermutationMutation(CandidateGenerator):
else: else:
# If mutation is not successfull in supplied number # If mutation is not successfull in supplied number
# of trials, return initial structure. # of trials, return initial structure.
return atoms return None
...@@ -3,10 +3,11 @@ from abc import ABC, abstractmethod ...@@ -3,10 +3,11 @@ from abc import ABC, abstractmethod
from ase.data import covalent_radii from ase.data import covalent_radii
from ase.geometry import get_distances from ase.geometry import get_distances
from ase import Atoms from ase import Atoms
from ase.visualize import view from ase.visualize import view
from utils import check_valid_bondlengths from utils import check_valid_bondlengths, get_min_distances_as_fraction_of_covalent
import warnings
class CandidateGenerator(ABC): class CandidateGenerator(ABC):
"""Baseclass for mutation and crossover operations as well """Baseclass for mutation and crossover operations as well
...@@ -29,13 +30,19 @@ class CandidateGenerator(ABC): ...@@ -29,13 +30,19 @@ class CandidateGenerator(ABC):
problems with GOFEE, as GPR-relaxations and dual-steps might problems with GOFEE, as GPR-relaxations and dual-steps might
result in structures that does not obey blmin/blmax. result in structures that does not obey blmin/blmax.
""" """
def __init__(self, blmin=0.7, blmax=1.4, force_all_bonds_valid=False): def __init__(self, blmin=0.7, blmax=1.4, constraints=None,
force_all_bonds_valid=False, *args, **kwargs):
self.blmin = blmin self.blmin = blmin
self.blmax = blmax self.blmax = blmax
self.constraints = constraints
self.force_all_bonds_valid = force_all_bonds_valid self.force_all_bonds_valid = force_all_bonds_valid
self.description = 'Unspecified' self.description = 'Unspecified'
def check_valid_bondlengths(self, a, indices=None, check_too_close=True, check_isolated=True): def check_valid_bondlengths(self, a, indices=None,
check_too_close=True, check_isolated=True):
""" Method to check if bondlengths are valid according to blmin
amd blmax.
"""
if self.force_all_bonds_valid: if self.force_all_bonds_valid:
# Check all bonds (mainly for testing) # Check all bonds (mainly for testing)
return check_valid_bondlengths(a, self.blmin, self.blmax+0.1, return check_valid_bondlengths(a, self.blmin, self.blmax+0.1,
...@@ -44,30 +51,73 @@ class CandidateGenerator(ABC): ...@@ -44,30 +51,73 @@ class CandidateGenerator(ABC):
else: else:
# Check only specified ones # Check only specified ones
# (typically only for the atoms changed during operation) # (typically only for the atoms changed during operation)
return check_valid_bondlengths(a, self.blmin, self.blmax+0.1, indices=indices, return check_valid_bondlengths(a, self.blmin, self.blmax+0.1,
indices=indices,
check_too_close=check_too_close, check_too_close=check_too_close,
check_isolated=check_isolated) check_isolated=check_isolated)
def get_new_candidate(self, parents=None):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
# Check bondlengths
if parents is not None:
for i, parent in enumerate(parents):
self.check_bondlengths(parent, f'SHORT BONDS IN PARENT {i}')
for _ in range(5): # Make five tries
a = self.operation(parents)
if a is not None:
a = self.finalize(a)
break
else:
return None
return a
def train(self):
""" Method to be implemented for the operations that rely on
a Machine-Learned model to perform more informed/guided
mutation and crossover operations.
"""
pass
@abstractmethod @abstractmethod
def get_new_candidate(self): def operation(self):
pass pass
def finalize(self, a, a0=None, successfull=True): def finalize(self, a, a0=None, successfull=True):
""" Method to finalize new candidates.
"""
# Wrap positions
a.wrap()
# finalize description
if successfull: if successfull:
description = self.description description = self.description
else: else:
description = 'failed ' + self.description description = 'failed ' + self.description
# Save description
try: try:
a.info['key_value_pairs']['origin'] = description a.info['key_value_pairs']['origin'] = description
except: except:
a.info['key_value_pairs'] = {'origin': description} a.info['key_value_pairs'] = {'origin': description}
if self.force_all_bonds_valid:
# Check all bonds self.check_bondlengths(a, 'SHORT BONDS AFTER OPPERATION')
valid_bondlengths = self.check_valid_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid'
return a return a
def check_bondlengths(self, a, warn_text):
if self.force_all_bonds_valid:
# Check all bonds
valid_bondlengths = self.check_valid_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid'
else:
d_shortest_bond, index_shortest_bond = get_min_distances_as_fraction_of_covalent(a)
if d_shortest_bond < self.blmin:
text = f"""{warn_text}:
Atom {index_shortest_bond} has bond with d={d_shortest_bond}d_covalent"""
warnings.warn(text)
class OperationSelector(): class OperationSelector():
"""Class to produce new candidates by applying one of the """Class to produce new candidates by applying one of the
...@@ -107,11 +157,25 @@ class OperationSelector(): ...@@ -107,11 +157,25 @@ class OperationSelector():
each list of operations, if multiple are present. each list of operations, if multiple are present.
""" """
for op_list, rho_list in zip(self.operations, self.rho): for op_list, rho_list in zip(self.operations, self.rho):
to_use = self.__get_index__(rho_list) for i_trial in range(5): # Do five trials
anew = op_list[to_use].get_new_candidate(parents) to_use = self.__get_index__(rho_list)
parents[0] = anew anew = op_list[to_use].get_new_candidate(parents)
if anew is not None:
parents[0] = anew
break
else:
anew = parents[0]
anew = op_list[to_use].finalize(anew, successfull=False)
return anew return anew
def train(self, data):
""" Method to train all trainable operations in
self.operations.
"""
for oplist in self.operations:
for operation in oplist:
operation.train(data)
def random_pos(box): def random_pos(box):
""" Returns a random position within the box """ Returns a random position within the box
...@@ -151,21 +215,29 @@ class OperationConstraint(): ...@@ -151,21 +215,29 @@ class OperationConstraint():
""" Returns whether positions are valid under the """ Returns whether positions are valid under the
constraints or not. constraints or not.
""" """
if np.ndim(positions) == 1:
pos = positions.reshape(-1,3)
else:
pos = positions
if self.box is not None: if self.box is not None:
pass pass
if self.x is not None: if self.xlim is not None:
if (np.any(positions[:,0] < self.xlim[0]) or if (np.any(pos[:,0] < self.xlim[0]) or
np.any(positions[:,0] > self.xlim[1])): np.any(pos[:,0] > self.xlim[1])):
return False return False
if self.y is not None: if self.ylim is not None:
if (np.any(positions[:,1] < self.ylim[0]) or if (np.any(pos[:,1] < self.ylim[0]) or
np.any(positions[:,1] > self.ylim[1])): np.any(pos[:,1] > self.ylim[1])):
return False return False
if self.z is not None: if self.zlim is not None:
if (np.any(positions[:,2] < self.zlim[0]) or if (np.any(pos[:,2] < self.zlim[0]) or
np.any(positions[:,2] > self.zlim[1])): np.any(pos[:,2] > self.zlim[1])):
return False return False
return True
class StartGenerator(CandidateGenerator): class StartGenerator(CandidateGenerator):
""" Class used to generate random initial candidates. """ Class used to generate random initial candidates.
The candidates are generated by iteratively adding in The candidates are generated by iteratively adding in
...@@ -197,17 +269,17 @@ class StartGenerator(CandidateGenerator): ...@@ -197,17 +269,17 @@ class StartGenerator(CandidateGenerator):
False the atoms in the slab are also included. False the atoms in the slab are also included.
""" """
def __init__(self, slab, stoichiometry, box_to_place_in, def __init__(self, slab, stoichiometry, box_to_place_in,
blmin=0.7, blmax=1.4, cluster=False, description='StartGenerator'): cluster=False, description='StartGenerator',
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax) *args, **kwargs):
CandidateGenerator.__init__(self, *args, **kwargs)
self.slab = slab self.slab = slab
self.stoichiometry = stoichiometry self.stoichiometry = stoichiometry
self.box = box_to_place_in self.box = box_to_place_in
self.cluster = cluster self.cluster = cluster
self.description = description self.description = description
def get_new_candidate(self, parents=None): def operation(self, parents=None):
a = self.make_structure() a = self.make_structure()
a = self.finalize(a)
return a return a
def make_structure(self): def make_structure(self):
......
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: b59e7c34155208271a68e75851939773
tags: 645f666f9bcd5a90fca523b33c5a78b7
gofee.candidate\_operations package
===================================
Submodules
----------