bfgslinesearch_constrained.py 4.71 KB
Newer Older
1
2
3
4
5
6
7
8
import numpy as np
from ase.optimize.bfgslinesearch import BFGSLineSearch
from ase.constraints import FixAtoms
from ase.data import covalent_radii
from ase.ga.utilities import get_mic_distance

from utils import check_valid_bondlengths

9
import traceback
Malthe Kjær Bisbo's avatar
Malthe Kjær Bisbo committed
10
import sys
11

12
13
14
15

class BFGSLineSearch_constrained(BFGSLineSearch):
    """BFGSLineSearch that stops early when optional constraints are violated.

    Besides the usual force criterion, ``converged`` also checks three
    optional constraints after each step:

    * maximum per-atom displacement from ``pos_init`` (``max_relax_dist``),
    * a position constraint on the atoms not fixed by ``FixAtoms``
      (``position_constraint``),
    * bond-length limits (``blmin`` / ``blmax``).

    If any of them is violated the optimizer reports convergence, so the
    relaxation is terminated at that point.
    """

    def __init__(self, atoms, pos_init=None, restart=None, logfile='-', maxstep=.2,
                 trajectory=None, c1=0.23, c2=0.46, alpha=10.0, stpmax=50.0,
                 master=None, force_consistent=None,
                 blmin=None, blmax=None, max_relax_dist=4.0,
                 position_constraint=None, rk=None):
        """Set up the constrained line-search BFGS optimizer.

        Parameters
        ----------
        atoms : ase.Atoms
            Structure to optimize.
        pos_init : array-like, optional
            Reference positions for the displacement check. A copy is
            stored. Defaults to a copy of ``atoms.positions``.
        blmin, blmax : optional
            Bond-length limits forwarded to ``check_valid_bondlengths``.
            The bond-length check is skipped when both are None.
        max_relax_dist : float or None
            Maximum minimum-image distance any atom may move away from
            ``pos_init``. It is expressed in the same length units as
            ``atoms.positions``, not scaled by covalent radii. None
            disables the check.
        position_constraint : optional
            Object exposing ``check_if_valid(positions)``. It is applied to
            the positions of atoms not fixed by any ``FixAtoms`` constraint.
        rk : optional
            Stored as ``self.rk`` (marked "for testing"); not used by this
            class.

        All remaining arguments are forwarded unchanged to
        ``BFGSLineSearch.__init__``.
        """
        self.rk = rk  # for testing

        self.blmin = blmin
        self.blmax = blmax
        self.position_constraint = position_constraint
        self.max_relax_dist = max_relax_dist

        # Always copy the reference positions so they cannot alias an array
        # that is modified later (e.g. atoms.positions during relaxation).
        reference = atoms.positions if pos_init is None else pos_init
        self.pos_init = np.array(reference, dtype=float)

        self.cell = atoms.get_cell()
        self.pbc = atoms.get_pbc()

        BFGSLineSearch.__init__(self, atoms, restart=restart, logfile=logfile, maxstep=maxstep,
                                trajectory=trajectory, c1=c1, c2=c2, alpha=alpha, stpmax=stpmax,
                                master=master, force_consistent=force_consistent)

    def converged(self, forces=None):
        """Did the optimization converge, or must it stop due to constraints?

        Returns True when the largest per-atom force norm is below
        ``self.fmax``, or when ``check_constraints`` reports a violation.
        For atoms objects providing ``get_curvature``, the force criterion
        combined with negative curvature is used and constraints are not
        checked.
        """
        if forces is None:
            forces = self.atoms.get_forces()
        if hasattr(self.atoms, 'get_curvature'):
            return ((forces**2).sum(axis=1).max() < self.fmax**2 and
                    self.atoms.get_curvature() < 0.0)

        # A violated constraint is reported as "converged" so that the
        # optimizer's run loop stops at this step.
        if self.check_constraints():
            return True

        return (forces**2).sum(axis=1).max() < self.fmax**2

    def check_constraints(self):
        """Return True if any constraint is violated (the run should stop).

        Evaluation stops at the first failing check.
        """
        return not (self.check_displacement()
                    and self.check_positions()
                    and self.check_bondlengths())

    def check_displacement(self):
        """Return False if any atom moved farther than ``max_relax_dist``.

        Distances are measured from ``self.pos_init`` using the minimum-image
        convention with the cell/pbc captured at construction. Returns True
        when ``max_relax_dist`` is None.
        """
        if self.max_relax_dist is None:
            return True
        positions = self.atoms.get_positions()
        return not any(
            get_mic_distance(p_init, p_now, self.cell, self.pbc) > self.max_relax_dist
            for p_init, p_now in zip(self.pos_init, positions)
        )

    def check_positions(self):
        """Return the position constraint's verdict for the non-fixed atoms.

        Atoms covered by any ``FixAtoms`` constraint are excluded (the union
        over all such constraints). When there is no ``FixAtoms`` constraint,
        all atoms are checked. Returns True when ``position_constraint`` is
        None.
        """
        if self.position_constraint is None:
            return True
        fixed = np.zeros(len(self.atoms), dtype=bool)
        for constraint in self.atoms.constraints:
            if isinstance(constraint, FixAtoms):
                fixed[constraint.get_indices()] = True
        pos_not_fixed = self.atoms.positions[~fixed]
        return self.position_constraint.check_if_valid(pos_not_fixed)

    def check_bondlengths(self):
        """Return whether bond lengths satisfy ``blmin``/``blmax``.

        Delegates to ``check_valid_bondlengths``. Returns True when both
        limits are None.
        """
        if self.blmin is None and self.blmax is None:
            return True
        return bool(check_valid_bondlengths(self.atoms, self.blmin, self.blmax))


def relax(structure, calc, Fmax=0.05, steps_max=200, max_relax_dist=None, position_constraint=None):
    """Relax a copy of *structure* with ``BFGSLineSearch_constrained``.

    Parameters
    ----------
    structure : ase.Atoms
        Input structure. It is not modified; a copy is relaxed.
    calc :
        Calculator attached to the copy.
    Fmax : float
        Force convergence threshold, passed as ``fmax`` to ``dyn.run``.
    steps_max : int
        Maximum number of optimizer steps.
    max_relax_dist : float or None
        Maximum per-atom displacement from the starting positions before
        the relaxation is stopped. None disables the check.
    position_constraint : optional
        Object with a ``check_if_valid(positions)`` method, checked on the
        non-fixed atoms after each step. None disables the check.

    Returns
    -------
    ase.Atoms
        The relaxed copy. If the optimizer raises, the error and traceback
        are reported, and the copy is returned in whatever state it reached.
    """
    a = structure.copy()
    # Attach the calculator (replaces the deprecated Atoms.set_calculator).
    a.calc = calc
    pos_init = a.get_positions()

    # Best effort: a failing line search must not abort the caller, so the
    # error is reported and the (partially) relaxed structure is returned.
    try:
        dyn = BFGSLineSearch_constrained(a,
                                         logfile=None,
                                         pos_init=pos_init,
                                         max_relax_dist=max_relax_dist,
                                         position_constraint=position_constraint)
        dyn.run(fmax=Fmax, steps=steps_max)
    except Exception as err:
        print('Error in surrogate-relaxation:', err, flush=True)
        # print_exc() writes to sys.stderr by default; print it only once.
        traceback.print_exc()
    return a