Commit 86e35abb authored by Malthe Kjær Bisbo's avatar Malthe Kjær Bisbo
Browse files

fixed problem with constraints on candidateGenerator

parent 5de4909f
......@@ -92,9 +92,9 @@ class RattleMutation(OffspringOperation):
a.positions[i] += pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
obey_constraint = self.check_constraints(a.positions[i])
# Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
valid_bondlengths = self.check_bondlengths(a, indices=[i])
valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
......@@ -170,9 +170,9 @@ class RattleMutation2(OffspringOperation):
a.positions[i] = np.copy(a.positions[j]) + pos_add
# Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i])
obey_constraint = self.check_constraints(a.positions[i])
# Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i])
valid_bondlengths = self.check_bondlengths(a, indices=[i])
valid_operation = valid_bondlengths and obey_constraint
if not valid_operation:
......@@ -253,7 +253,7 @@ class PermutationMutation(OffspringOperation):
a.positions[j] = pos_i
# Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i,j])
valid_bondlengths = self.check_bondlengths(a, indices=[i,j])
if not valid_bondlengths:
a.positions[i] = pos_i
......
......@@ -38,7 +38,7 @@ class OffspringOperation(ABC):
self.force_all_bonds_valid = force_all_bonds_valid
self.description = 'Unspecified'
def check_valid_bondlengths(self, a, indices=None,
def check_bondlengths(self, a, indices=None,
check_too_close=True, check_isolated=True):
""" Method to check if bondlengths are valid according to blmin
amd blmax.
......@@ -63,7 +63,7 @@ class OffspringOperation(ABC):
# Check bondlengths
if parents is not None:
for i, parent in enumerate(parents):
self.check_bondlengths(parent, f'SHORT BONDS IN PARENT {i}')
self.check_all_bondlengths(parent, f'SHORT BONDS IN PARENT {i}')
for _ in range(5): # Make five tries
a = self.operation(parents)
......@@ -103,13 +103,13 @@ class OffspringOperation(ABC):
except:
a.info['key_value_pairs'] = {'origin': description}
self.check_bondlengths(a, 'SHORT BONDS AFTER OPPERATION')
self.check_all_bondlengths(a, 'SHORT BONDS AFTER OPPERATION')
return a
def check_bondlengths(self, a, warn_text):
def check_all_bondlengths(self, a, warn_text):
if self.force_all_bonds_valid:
# Check all bonds
valid_bondlengths = self.check_valid_bondlengths(a)
valid_bondlengths = self.check_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid'
else:
d_shortest_bond, index_shortest_bond = get_min_distances_as_fraction_of_covalent(a)
......@@ -118,6 +118,13 @@ class OffspringOperation(ABC):
Atom {index_shortest_bond} has bond with d={d_shortest_bond}d_covalent"""
warnings.warn(text)
def check_constraints(self, indices=None):
if self.constraints is not None:
valid = self.constraints.check_if_valid(indices)
return valid
else:
return True
class CandidateGenerator():
"""Class to produce new candidates by applying one of the
candidate generation operations which is supplied in the
......@@ -297,18 +304,18 @@ class StartGenerator(OffspringOperation):
a += Atoms([num[i]], posi.reshape(1,3))
# Check if position of new atom is valid
not_too_close = self.check_valid_bondlengths(a, indices=[Nslab+i],
not_too_close = self.check_bondlengths(a, indices=[Nslab+i],
check_too_close=True,
check_isolated=False)
if len(a) == 1: # The first atom
not_isolated = True
else:
if self.cluster: # Check isolation excluding slab atoms.
not_isolated = self.check_valid_bondlengths(a[Nslab:], indices=[Nslab+i],
not_isolated = self.check_bondlengths(a[Nslab:], indices=[Nslab+i],
check_too_close=False,
check_isolated=True)
else: # All atoms.
not_isolated = self.check_valid_bondlengths(a, indices=[Nslab+i],
not_isolated = self.check_bondlengths(a, indices=[Nslab+i],
check_too_close=False,
check_isolated=True)
valid_bondlengths = not_too_close and not_isolated
......@@ -365,7 +372,7 @@ if __name__ == '__main__':
"""
for a in traj:
vb = rattle.check_valid_bondlengths(a)
vb = rattle.check_bondlengths(a)
print(vb)
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment