Commit e9f52417 authored by Malthe Kjær Bisbo's avatar Malthe Kjær Bisbo
Browse files

Cleaned up cand_opp and missing missing wrap() calls

parent 531b2f5e
......@@ -94,7 +94,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
d = np.array([get_mic_distance(p1,p2,self.cell,self.pbc) for p1,p2 in zip(self.pos_init,self.atoms.get_positions())])
max_covDisplace = np.max(d/self.d_cov)
if max_covDisplace > self.dmax_cov:
print('RELAX STOPPED on rank {} - max relax dist (>{} covDist) reached.'.format(self.rk, self.dmax_cov), flush=True)
valid_displace = False
return valid_displace
......@@ -109,7 +108,6 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
indices = np.delete(np.arange(self.atoms.get_number_of_atoms()), indices_fixed)
pos_z = self.atoms.positions[indices,2]
if np.any(pos_z < self.zlim[0]) or np.any(pos_z > self.zlim[1]):
print('RELAXATION STOPPED on rank {} due to - zlim crossed.'.format(self.rk))
valid_z = False
return valid_z
......@@ -118,6 +116,5 @@ class BFGSLineSearch_constrained(BFGSLineSearch):
if self.blmin is not None or self.blmax is not None:
valid_bondlengths = check_valid_bondlengths(self.atoms, self.blmin, self.blmax)
if not valid_bondlengths:
#print('RELAXATION STOPPED on rank {} due to - invalid bondlengths.'.format(self.rk))
valid_bondlengths = False
return valid_bondlengths
......@@ -9,7 +9,8 @@ from candidate_operations.candidate_generation import CandidateGenerator
def pos_add_sphere(rattle_strength):
# Rattle within a sphere
"""Help function for rattling within a sphere
"""
r = rattle_strength * np.random.rand()**(1/3)
theta = np.random.uniform(low=0, high=2*np.pi)
phi = np.random.uniform(low=0, high=np.pi)
......@@ -19,7 +20,9 @@ def pos_add_sphere(rattle_strength):
return pos_add
def pos_add_sphere_shell(rmin, rmax):
# Rattle within a sphere
"""Help function for rattling atoms within a spherical
shell.
"""
r = np.random.uniform(rmin**3, rmax**3)**(1/3)
theta = np.random.uniform(low=0, high=2*np.pi)
phi = np.random.uniform(low=0, high=np.pi)
......@@ -52,25 +55,22 @@ class RattleMutation(CandidateGenerator):
description: Name of the operation, which will be saved in
info-dict of structures, on which the operation is applied.
"""
def __init__(self, n_top, Nrattle=3, rattle_range=3, blmin=0.7, blmax=1.4,
force_all_bonds_valid=False, description='RattleMutation'):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax,
force_all_bonds_valid=force_all_bonds_valid)
def __init__(self, n_top, Nrattle=3, rattle_range=3,
description='RattleMutation', *args, **kwargs):
CandidateGenerator.__init__(self, *args, **kwargs)
self.description = description
self.n_top = n_top
self.probability = Nrattle/n_top
self.rattle_range = rattle_range
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents):
def operation(self, parents):
a = parents[0]
a = self.rattle(a)
a = self.finalize(a)
return a
def rattle(self, atoms):
"""Standardized candidate generation method for all mutation
and crossover operations.
""" Rattles atoms one at a time within a sphere of radius
self.rattle_range.
"""
a = atoms.copy()
Natoms = len(a)
......@@ -91,18 +91,22 @@ class RattleMutation(CandidateGenerator):
pos_add = pos_add_sphere(self.rattle_range)
a.positions[i] += pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
# Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
if not valid_bondlengths:
valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
a.positions[i] = posi_0
else:
break
if valid_bondlengths:
if valid_operation:
return a
else:
# If mutation is not successfull in supplied number
# of trials, return initial structure.
return atoms
return None
class RattleMutation2(CandidateGenerator):
"""Class to perform rattle mutations on structures.
......@@ -124,25 +128,22 @@ class RattleMutation2(CandidateGenerator):
description: Name of the operation, which will be saved in
info-dict of structures, on which the operation is applied.
"""
def __init__(self, n_top, Nrattle=3, blmin=0.7, blmax=1.4,
force_all_bonds_valid=False, description='RattleMutation'):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax,
force_all_bonds_valid=force_all_bonds_valid)
def __init__(self, n_top, Nrattle=3, description='RattleMutation',
*args, **kwargs):
CandidateGenerator.__init__(self, *args, **kwargs)
self.description = description
self.n_top = n_top
self.probability = Nrattle/n_top
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
def operation(self, parents):
a = parents[0]
a = self.rattle(a)
a = self.finalize(a)
return a
def rattle(self, atoms):
""" Repeatedly rattles a random atom to the visinity of another
random atom.
"""
a = atoms.copy()
Natoms = len(a)
Nslab = Natoms - self.n_top
......@@ -168,19 +169,22 @@ class RattleMutation2(CandidateGenerator):
pos_add = pos_add_sphere_shell(rmin, rmax)
a.positions[i] = np.copy(a.positions[j]) + pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
# Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
if not valid_bondlengths:
valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
a.positions[i] = posi_0
else:
break
if valid_bondlengths:
if valid_operation:
return a
else:
# If mutation is not successfull in supplied number
# of trials, return initial structure.
return atoms
return None
class PermutationMutation(CandidateGenerator):
......@@ -204,25 +208,21 @@ class PermutationMutation(CandidateGenerator):
info-dict of structures, on which the operation is applied.
"""
def __init__(self, n_top, Npermute=3, blmin=0.7, blmax=1.4,
force_all_bonds_valid=False, description='PermutationMutation'):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax,
force_all_bonds_valid=force_all_bonds_valid)
def __init__(self, n_top, Npermute=3,
description='PermutationMutation', *args, **kwargs):
CandidateGenerator.__init__(self, *args, **kwargs)
self.description = description
self.n_top = n_top
self.probability = Npermute/n_top
self.force_all_bonds_valid = force_all_bonds_valid
def get_new_candidate(self, parents):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
def operation(self, parents):
a = parents[0]
a = self.mutate(a)
a = self.finalize(a)
a = self.permute(a)
return a
def mutate(self, atoms):
def permute(self, atoms):
""" Permutes atoms of different type in structure.
"""
a = atoms.copy()
Natoms = len(a)
Nslab = Natoms - self.n_top
......@@ -265,5 +265,5 @@ class PermutationMutation(CandidateGenerator):
else:
# If mutation is not successfull in supplied number
# of trials, return initial structure.
return atoms
return None
......@@ -3,10 +3,11 @@ from abc import ABC, abstractmethod
from ase.data import covalent_radii
from ase.geometry import get_distances
from ase import Atoms
from ase.visualize import view
from utils import check_valid_bondlengths
from utils import check_valid_bondlengths, get_min_distances_as_fraction_of_covalent
import warnings
class CandidateGenerator(ABC):
"""Baseclass for mutation and crossover operations as well
......@@ -29,13 +30,19 @@ class CandidateGenerator(ABC):
problems with GOFEE, as GPR-relaxations and dual-steps might
result in structures that does not obey blmin/blmax.
"""
def __init__(self, blmin=0.7, blmax=1.4, force_all_bonds_valid=False):
def __init__(self, blmin=0.7, blmax=1.4, constraints=None,
force_all_bonds_valid=False, *args, **kwargs):
self.blmin = blmin
self.blmax = blmax
self.constraints = constraints
self.force_all_bonds_valid = force_all_bonds_valid
self.description = 'Unspecified'
def check_valid_bondlengths(self, a, indices=None, check_too_close=True, check_isolated=True):
def check_valid_bondlengths(self, a, indices=None,
check_too_close=True, check_isolated=True):
""" Method to check if bondlengths are valid according to blmin
amd blmax.
"""
if self.force_all_bonds_valid:
# Check all bonds (mainly for testing)
return check_valid_bondlengths(a, self.blmin, self.blmax+0.1,
......@@ -44,30 +51,73 @@ class CandidateGenerator(ABC):
else:
# Check only specified ones
# (typically only for the atoms changed during operation)
return check_valid_bondlengths(a, self.blmin, self.blmax+0.1, indices=indices,
return check_valid_bondlengths(a, self.blmin, self.blmax+0.1,
indices=indices,
check_too_close=check_too_close,
check_isolated=check_isolated)
def get_new_candidate(self, parents=None):
"""Standardized candidate generation method for all mutation
and crossover operations.
"""
# Check bondlengths
if parents is not None:
for i, parent in enumerate(parents):
self.check_bondlengths(parent, f'SHORT BONDS IN PARENT {i}')
for _ in range(5): # Make five tries
a = self.operation(parents)
if a is not None:
a = self.finalize(a)
break
else:
return None
return a
def train(self):
""" Method to be implemented for the operations that rely on
a Machine-Learned model to perform more informed/guided
mutation and crossover operations.
"""
pass
@abstractmethod
def get_new_candidate(self):
def operation(self):
pass
def finalize(self, a, a0=None, successfull=True):
""" Method to finalize new candidates.
"""
# Wrap positions
a.wrap()
# finalize description
if successfull:
description = self.description
else:
description = 'failed ' + self.description
# Save description
try:
a.info['key_value_pairs']['origin'] = description
except:
a.info['key_value_pairs'] = {'origin': description}
if self.force_all_bonds_valid:
# Check all bonds
valid_bondlengths = self.check_valid_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid'
self.check_bondlengths(a, 'SHORT BONDS AFTER OPPERATION')
return a
def check_bondlengths(self, a, warn_text):
if self.force_all_bonds_valid:
# Check all bonds
valid_bondlengths = self.check_valid_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid'
else:
d_shortest_bond, index_shortest_bond = get_min_distances_as_fraction_of_covalent(a)
if d_shortest_bond < self.blmin:
text = f"""{warn_text}:
Atom {index_shortest_bond} has bond with d={d_shortest_bond}d_covalent"""
warnings.warn(text)
class OperationSelector():
"""Class to produce new candidates by applying one of the
......@@ -107,11 +157,25 @@ class OperationSelector():
each list of operations, if multiple are present.
"""
for op_list, rho_list in zip(self.operations, self.rho):
to_use = self.__get_index__(rho_list)
anew = op_list[to_use].get_new_candidate(parents)
parents[0] = anew
for i_trial in range(5): # Do five trials
to_use = self.__get_index__(rho_list)
anew = op_list[to_use].get_new_candidate(parents)
if anew is not None:
parents[0] = anew
break
else:
anew = parents[0]
anew = op_list[to_use].finalize(anew, successfull=False)
return anew
def train(self, data):
""" Method to train all trainable operations in
self.operations.
"""
for oplist in self.operations:
for operation in oplist:
operation.train(data)
def random_pos(box):
""" Returns a random position within the box
......@@ -151,21 +215,29 @@ class OperationConstraint():
""" Returns whether positions are valid under the
constraints or not.
"""
if np.ndim(positions) == 1:
pos = positions.reshape(-1,3)
else:
pos = positions
if self.box is not None:
pass
if self.x is not None:
if (np.any(positions[:,0] < self.xlim[0]) or
np.any(positions[:,0] > self.xlim[1])):
if self.xlim is not None:
if (np.any(pos[:,0] < self.xlim[0]) or
np.any(pos[:,0] > self.xlim[1])):
return False
if self.y is not None:
if (np.any(positions[:,1] < self.ylim[0]) or
np.any(positions[:,1] > self.ylim[1])):
if self.ylim is not None:
if (np.any(pos[:,1] < self.ylim[0]) or
np.any(pos[:,1] > self.ylim[1])):
return False
if self.z is not None:
if (np.any(positions[:,2] < self.zlim[0]) or
np.any(positions[:,2] > self.zlim[1])):
if self.zlim is not None:
if (np.any(pos[:,2] < self.zlim[0]) or
np.any(pos[:,2] > self.zlim[1])):
return False
return True
class StartGenerator(CandidateGenerator):
""" Class used to generate random initial candidates.
The candidates are generated by iteratively adding in
......@@ -197,17 +269,17 @@ class StartGenerator(CandidateGenerator):
False the atoms in the slab are also included.
"""
def __init__(self, slab, stoichiometry, box_to_place_in,
blmin=0.7, blmax=1.4, cluster=False, description='StartGenerator'):
CandidateGenerator.__init__(self, blmin=blmin, blmax=blmax)
cluster=False, description='StartGenerator',
*args, **kwargs):
CandidateGenerator.__init__(self, *args, **kwargs)
self.slab = slab
self.stoichiometry = stoichiometry
self.box = box_to_place_in
self.cluster = cluster
self.description = description
def get_new_candidate(self, parents=None):
def operation(self, parents=None):
a = self.make_structure()
a = self.finalize(a)
return a
def make_structure(self):
......
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: b59e7c34155208271a68e75851939773
tags: 645f666f9bcd5a90fca523b33c5a78b7
gofee.candidate\_operations package
===================================
Submodules
----------
gofee.candidate\_operations.basic\_mutations module
---------------------------------------------------
.. automodule:: gofee.candidate_operations.basic_mutations
:members:
:undoc-members:
:show-inheritance:
gofee.candidate\_operations.candidate\_generation module
--------------------------------------------------------
.. automodule:: gofee.candidate_operations.candidate_generation
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: gofee.candidate_operations
:members:
:undoc-members:
:show-inheritance:
gofee package
=============
Subpackages
-----------
.. toctree::
gofee.candidate_operations
gofee.surrogate
Submodules
----------
gofee.bfgslinesearch\_constrained module
----------------------------------------
.. automodule:: gofee.bfgslinesearch_constrained
:members:
:undoc-members:
:show-inheritance:
gofee.bfgslinesearch\_zlim module
---------------------------------
.. automodule:: gofee.bfgslinesearch_zlim
:members:
:undoc-members:
:show-inheritance:
gofee.gofee module
------------------
.. automodule:: gofee.gofee
:members:
:undoc-members:
:show-inheritance:
gofee.parallel\_utils module
----------------------------
.. automodule:: gofee.parallel_utils
:members:
:undoc-members:
:show-inheritance:
gofee.population module
-----------------------
.. automodule:: gofee.population
:members:
:undoc-members:
:show-inheritance:
gofee.utils module
------------------
.. automodule:: gofee.utils
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: gofee
:members:
:undoc-members:
:show-inheritance:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment