tSNESProblem.hh - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
HTML git clone git://src.adamsgaard.dk/pism
DIR Log
DIR Files
DIR Refs
DIR LICENSE
---
tSNESProblem.hh (5414B)
---
1 // Copyright (C) 2011, 2012, 2014, 2015, 2016, 2017 David Maxwell
2 //
3 // This file is part of PISM.
4 //
5 // PISM is free software; you can redistribute it and/or modify it under the
6 // terms of the GNU General Public License as published by the Free Software
7 // Foundation; either version 3 of the License, or (at your option) any later
8 // version.
9 //
10 // PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 // details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with PISM; if not, write to the Free Software
17 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18
19
20 #ifndef _SNESPROBLEM_H_
21 #define _SNESPROBLEM_H_
22
23 #include "pism/util/IceGrid.hh" // inline implementation in the header uses IceGrid
24 #include "pism/util/Vector2.hh" // to get Vector2
25 #include "pism/util/petscwrappers/SNES.hh"
26 #include "pism/util/Logger.hh"
27
28 namespace pism {
29
30 template<int DOF, class U> class SNESProblem {
31 public:
32 SNESProblem(IceGrid::ConstPtr g);
33
34 virtual ~SNESProblem();
35
36 virtual void solve();
37
38 virtual const std::string& name();
39
40 virtual Vec solution()
41 {
42 return m_X;
43 }
44
45 protected:
46
47 virtual void compute_local_function(DMDALocalInfo *info, const U **xg, U **yg) = 0;
48 virtual void compute_local_jacobian(DMDALocalInfo *info, const U **x, Mat B) = 0;
49
50 IceGrid::ConstPtr m_grid;
51
52 petsc::Vec m_X;
53 petsc::SNES m_snes;
54 petsc::DM::Ptr m_DA;
55
56 private:
57
58 struct CallbackData {
59 DM da;
60 SNESProblem<DOF,U> *solver;
61 };
62
63 CallbackData m_callbackData;
64
65 static PetscErrorCode function_callback(DMDALocalInfo *info, const U **x, U **f,
66 CallbackData *);
67 static PetscErrorCode jacobian_callback(DMDALocalInfo *info, const U **x, Mat B,
68 CallbackData *);
69 };
70
71 typedef SNESProblem<1,double> SNESScalarProblem;
72 typedef SNESProblem<2,Vector2> SNESVectorProblem;
73
74 template<int DOF, class U>
75 PetscErrorCode SNESProblem<DOF,U>::function_callback(DMDALocalInfo *info,
76 const U **x, U **f,
77 SNESProblem<DOF,U>::CallbackData *cb) {
78 try {
79 cb->solver->compute_local_function(info,x,f);
80 } catch (...) {
81 MPI_Comm com = MPI_COMM_SELF;
82 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)cb->da, &com); CHKERRQ(ierr);
83 handle_fatal_errors(com);
84 SETERRQ(com, 1, "A PISM callback failed");
85 }
86 return 0;
87 }
88
89 template<int DOF, class U>
90 PetscErrorCode SNESProblem<DOF,U>::jacobian_callback(DMDALocalInfo *info,
91 const U **x, Mat J,
92 SNESProblem<DOF,U>::CallbackData *cb) {
93 try {
94 cb->solver->compute_local_jacobian(info, x, J);
95 } catch (...) {
96 MPI_Comm com = MPI_COMM_SELF;
97 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)cb->da, &com); CHKERRQ(ierr);
98 handle_fatal_errors(com);
99 SETERRQ(com, 1, "A PISM callback failed");
100 }
101 return 0;
102 }
103
104 template<int DOF, class U>
105 SNESProblem<DOF, U>::SNESProblem(IceGrid::ConstPtr g)
106 : m_grid(g) {
107
108 PetscErrorCode ierr;
109
110 int stencil_width=1;
111 m_DA = m_grid->get_dm(DOF, stencil_width);
112
113 ierr = DMCreateGlobalVector(*m_DA, m_X.rawptr());
114 PISM_CHK(ierr, "DMCreateGlobalVector");
115
116 ierr = SNESCreate(m_grid->com, m_snes.rawptr());
117 PISM_CHK(ierr, "SNESCreate");
118
119 // Set the SNES callbacks to call into our compute_local_function and compute_local_jacobian
120 m_callbackData.da = *m_DA;
121 m_callbackData.solver = this;
122
123 ierr = DMDASNESSetFunctionLocal(*m_DA, INSERT_VALUES,
124 (DMDASNESFunction)SNESProblem<DOF, U>::function_callback,
125 &m_callbackData);
126 PISM_CHK(ierr, "DMDASNESSetFunctionLocal");
127
128 ierr = DMDASNESSetJacobianLocal(*m_DA, (DMDASNESJacobian)SNESProblem<DOF, U>::jacobian_callback,
129 &m_callbackData);
130 PISM_CHK(ierr, "DMDASNESSetJacobianLocal");
131
132 ierr = DMSetMatType(*m_DA, "baij");
133 PISM_CHK(ierr, "DMSetMatType");
134
135 ierr = DMSetApplicationContext(*m_DA, &m_callbackData);
136 PISM_CHK(ierr, "DMSetApplicationContext");
137
138 ierr = SNESSetDM(m_snes, *m_DA);
139 PISM_CHK(ierr, "SNESSetDM");
140
141 ierr = SNESSetFromOptions(m_snes);
142 PISM_CHK(ierr, "SNESSetFromOptions");
143 }
144
145 template<int DOF, class U>
146 SNESProblem<DOF,U>::~SNESProblem() {
147 // empty
148 }
149
150 template<int DOF, class U>
151 const std::string& SNESProblem<DOF,U>::name() {
152 return "UnnamedProblem";
153 }
154
155 template<int DOF, class U>
156 void SNESProblem<DOF,U>::solve() {
157 PetscErrorCode ierr;
158
159 // Solve:
160 ierr = SNESSolve(m_snes,NULL,m_X); PISM_CHK(ierr, "SNESSolve");
161
162 // See if it worked.
163 SNESConvergedReason reason;
164 ierr = SNESGetConvergedReason(m_snes, &reason); PISM_CHK(ierr, "SNESGetConvergedReason");
165 if (reason < 0) {
166 throw RuntimeError::formatted(PISM_ERROR_LOCATION, "SNESProblem %s solve failed to converge (SNES reason %s)",
167 name().c_str(), SNESConvergedReasons[reason]);
168 }
169
170 m_grid->ctx()->log()->message(1, "SNESProblem %s converged (SNES reason %s)\n",
171 name().c_str(), SNESConvergedReasons[reason]);
172 }
173
174 } // end of namespace pism
175
176 #endif