Commit 86e35abb authored by Malthe Kjær Bisbo's avatar Malthe Kjær Bisbo
Browse files

fixed problem with constraints on candidateGenerator

parent 5de4909f
...@@ -92,9 +92,9 @@ class RattleMutation(OffspringOperation): ...@@ -92,9 +92,9 @@ class RattleMutation(OffspringOperation):
a.positions[i] += pos_add a.positions[i] += pos_add
# Check position constraint # Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i]) obey_constraint = self.check_constraints(a.positions[i])
# Check if rattle was valid # Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i]) valid_bondlengths = self.check_bondlengths(a, indices=[i])
valid_operation = valid_bondlengths and obey_constraint valid_operation = valid_bondlengths and obey_constraint
if not valid_operation: if not valid_operation:
...@@ -170,9 +170,9 @@ class RattleMutation2(OffspringOperation): ...@@ -170,9 +170,9 @@ class RattleMutation2(OffspringOperation):
a.positions[i] = np.copy(a.positions[j]) + pos_add a.positions[i] = np.copy(a.positions[j]) + pos_add
# Check position constraint # Check position constraint
obey_constraint = self.constraints.check_if_valid(a.positions[i]) obey_constraint = self.check_constraints(a.positions[i])
# Check if rattle was valid # Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i]) valid_bondlengths = self.check_bondlengths(a, indices=[i])
valid_operation = valid_bondlengths and obey_constraint valid_operation = valid_bondlengths and obey_constraint
if not valid_operation: if not valid_operation:
...@@ -253,7 +253,7 @@ class PermutationMutation(OffspringOperation): ...@@ -253,7 +253,7 @@ class PermutationMutation(OffspringOperation):
a.positions[j] = pos_i a.positions[j] = pos_i
# Check if rattle was valid # Check if rattle was valid
valid_bondlengths = self.check_valid_bondlengths(a, indices=[i,j]) valid_bondlengths = self.check_bondlengths(a, indices=[i,j])
if not valid_bondlengths: if not valid_bondlengths:
a.positions[i] = pos_i a.positions[i] = pos_i
......
...@@ -38,7 +38,7 @@ class OffspringOperation(ABC): ...@@ -38,7 +38,7 @@ class OffspringOperation(ABC):
self.force_all_bonds_valid = force_all_bonds_valid self.force_all_bonds_valid = force_all_bonds_valid
self.description = 'Unspecified' self.description = 'Unspecified'
def check_valid_bondlengths(self, a, indices=None, def check_bondlengths(self, a, indices=None,
check_too_close=True, check_isolated=True): check_too_close=True, check_isolated=True):
""" Method to check if bondlengths are valid according to blmin """ Method to check if bondlengths are valid according to blmin
amd blmax. amd blmax.
...@@ -63,7 +63,7 @@ class OffspringOperation(ABC): ...@@ -63,7 +63,7 @@ class OffspringOperation(ABC):
# Check bondlengths # Check bondlengths
if parents is not None: if parents is not None:
for i, parent in enumerate(parents): for i, parent in enumerate(parents):
self.check_bondlengths(parent, f'SHORT BONDS IN PARENT {i}') self.check_all_bondlengths(parent, f'SHORT BONDS IN PARENT {i}')
for _ in range(5): # Make five tries for _ in range(5): # Make five tries
a = self.operation(parents) a = self.operation(parents)
...@@ -103,13 +103,13 @@ class OffspringOperation(ABC): ...@@ -103,13 +103,13 @@ class OffspringOperation(ABC):
except: except:
a.info['key_value_pairs'] = {'origin': description} a.info['key_value_pairs'] = {'origin': description}
self.check_bondlengths(a, 'SHORT BONDS AFTER OPPERATION') self.check_all_bondlengths(a, 'SHORT BONDS AFTER OPPERATION')
return a return a
def check_bondlengths(self, a, warn_text): def check_all_bondlengths(self, a, warn_text):
if self.force_all_bonds_valid: if self.force_all_bonds_valid:
# Check all bonds # Check all bonds
valid_bondlengths = self.check_valid_bondlengths(a) valid_bondlengths = self.check_bondlengths(a)
assert valid_bondlengths, 'bondlengths are not valid' assert valid_bondlengths, 'bondlengths are not valid'
else: else:
d_shortest_bond, index_shortest_bond = get_min_distances_as_fraction_of_covalent(a) d_shortest_bond, index_shortest_bond = get_min_distances_as_fraction_of_covalent(a)
...@@ -118,6 +118,13 @@ class OffspringOperation(ABC): ...@@ -118,6 +118,13 @@ class OffspringOperation(ABC):
Atom {index_shortest_bond} has bond with d={d_shortest_bond}d_covalent""" Atom {index_shortest_bond} has bond with d={d_shortest_bond}d_covalent"""
warnings.warn(text) warnings.warn(text)
def check_constraints(self, indices=None):
if self.constraints is not None:
valid = self.constraints.check_if_valid(indices)
return valid
else:
return True
class CandidateGenerator(): class CandidateGenerator():
"""Class to produce new candidates by applying one of the """Class to produce new candidates by applying one of the
candidate generation operations which is supplied in the candidate generation operations which is supplied in the
...@@ -297,18 +304,18 @@ class StartGenerator(OffspringOperation): ...@@ -297,18 +304,18 @@ class StartGenerator(OffspringOperation):
a += Atoms([num[i]], posi.reshape(1,3)) a += Atoms([num[i]], posi.reshape(1,3))
# Check if position of new atom is valid # Check if position of new atom is valid
not_too_close = self.check_valid_bondlengths(a, indices=[Nslab+i], not_too_close = self.check_bondlengths(a, indices=[Nslab+i],
check_too_close=True, check_too_close=True,
check_isolated=False) check_isolated=False)
if len(a) == 1: # The first atom if len(a) == 1: # The first atom
not_isolated = True not_isolated = True
else: else:
if self.cluster: # Check isolation excluding slab atoms. if self.cluster: # Check isolation excluding slab atoms.
not_isolated = self.check_valid_bondlengths(a[Nslab:], indices=[Nslab+i], not_isolated = self.check_bondlengths(a[Nslab:], indices=[Nslab+i],
check_too_close=False, check_too_close=False,
check_isolated=True) check_isolated=True)
else: # All atoms. else: # All atoms.
not_isolated = self.check_valid_bondlengths(a, indices=[Nslab+i], not_isolated = self.check_bondlengths(a, indices=[Nslab+i],
check_too_close=False, check_too_close=False,
check_isolated=True) check_isolated=True)
valid_bondlengths = not_too_close and not_isolated valid_bondlengths = not_too_close and not_isolated
...@@ -365,7 +372,7 @@ if __name__ == '__main__': ...@@ -365,7 +372,7 @@ if __name__ == '__main__':
""" """
for a in traj: for a in traj:
vb = rattle.check_valid_bondlengths(a) vb = rattle.check_bondlengths(a)
print(vb) print(vb)
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment