IP_SSATaucTikhonovGNSolver.cc - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
git clone git://src.adamsgaard.dk/pism
---
IP_SSATaucTikhonovGNSolver.cc (16208B)
---
// Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017, 2019 David Maxwell and Constantine Khroulev
//
// This file is part of PISM.
//
// PISM is free software; you can redistribute it and/or modify it under the
// terms of the GNU General Public License as published by the Free Software
// Foundation; either version 3 of the License, or (at your option) any later
// version.
//
// PISM is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License
// along with PISM; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

#include "IP_SSATaucTikhonovGNSolver.hh"
#include "pism/util/TerminationReason.hh"
#include "pism/util/pism_options.hh"
#include "pism/util/ConfigInterface.hh"
#include "pism/util/IceGrid.hh"

namespace pism {
namespace inverse {

IP_SSATaucTikhonovGNSolver::IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward,
                                                       DesignVec &d0, StateVec &u_obs, double eta,
                                                       IPInnerProductFunctional<DesignVec> &designFunctional,
                                                       IPInnerProductFunctional<StateVec> &stateFunctional)
  : m_design_stencil_width(d0.stencil_width()),
    m_state_stencil_width(u_obs.stencil_width()),
    m_ssaforward(ssaforward),
    m_x(d0.grid(), "x", WITH_GHOSTS, m_design_stencil_width),
    m_tmp_D1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
    m_tmp_D2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
    m_tmp_D1Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
    m_tmp_D2Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
    m_tmp_S1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
    m_tmp_S2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
    m_tmp_S1Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
    m_tmp_S2Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
    m_GN_rhs(d0.grid(), "GN_rhs", WITHOUT_GHOSTS, 0),
    m_d0(d0),
    m_dGlobal(d0.grid(), "d (sans ghosts)", WITHOUT_GHOSTS, 0),
    m_d_diff(d0.grid(), "d_diff", WITH_GHOSTS, m_design_stencil_width),
    m_d_diff_lin(d0.grid(), "d_diff linearized", WITH_GHOSTS, m_design_stencil_width),
    m_h(d0.grid(), "h", WITH_GHOSTS, m_design_stencil_width),
    m_hGlobal(d0.grid(), "h (sans ghosts)", WITHOUT_GHOSTS),
    m_dalpha_rhs(d0.grid(), "dalpha rhs", WITHOUT_GHOSTS),
    m_dh_dalpha(d0.grid(), "dh_dalpha", WITH_GHOSTS, m_design_stencil_width),
    m_dh_dalphaGlobal(d0.grid(), "dh_dalpha", WITHOUT_GHOSTS),
    m_grad_design(d0.grid(), "grad design", WITHOUT_GHOSTS),
    m_grad_state(d0.grid(), "grad state", WITHOUT_GHOSTS),
    m_gradient(d0.grid(), "gradient", WITHOUT_GHOSTS),
    m_u_obs(u_obs),
    m_u_diff(d0.grid(), "du", WITH_GHOSTS, m_state_stencil_width),
    m_eta(eta),
    m_designFunctional(designFunctional),
    m_stateFunctional(stateFunctional),
    m_target_misfit(0.0)
{
  PetscErrorCode ierr;
  IceGrid::ConstPtr grid = m_d0.grid();
  m_comm = grid->com;

  m_d.reset(new DesignVec(grid, "d", WITH_GHOSTS, m_design_stencil_width));

  ierr = KSPCreate(grid->com, m_ksp.rawptr());
  PISM_CHK(ierr, "KSPCreate");

  ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
  PISM_CHK(ierr, "KSPSetOptionsPrefix");

  double ksp_rtol = 1e-5; // Soft tolerance
  ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
  PISM_CHK(ierr, "KSPSetTolerances");

  ierr = KSPSetType(m_ksp, KSPCG);
  PISM_CHK(ierr, "KSPSetType");

  PC pc;
  ierr = KSPGetPC(m_ksp, &pc);
  PISM_CHK(ierr, "KSPGetPC");

  ierr = PCSetType(pc, PCNONE);
  PISM_CHK(ierr, "PCSetType");

  ierr = KSPSetFromOptions(m_ksp);
  PISM_CHK(ierr, "KSPSetFromOptions");

  int nLocalNodes = grid->xm()*grid->ym();
  int nGlobalNodes = grid->Mx()*grid->My();
  ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
                        nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
  PISM_CHK(ierr, "MatCreateShell");

  typedef MatrixMultiplyCallback<IP_SSATaucTikhonovGNSolver,
                                 &IP_SSATaucTikhonovGNSolver::apply_GN> multCallback;
  multCallback::connect(m_mat_GN);

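  // The penalty parameter alpha = 1/eta weights the design (regularization) part of the
  // Tikhonov functional relative to the state misfit; see evaluate_objective_and_gradient(),
  // where the objective is assembled as alpha*J_design(d - d0) + J_state(F(d) - u_obs).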
  m_alpha = 1./m_eta;
  m_logalpha = log(m_alpha);

  m_tikhonov_adaptive = options::Bool("-tikhonov_adaptive", "Tikhonov adaptive");

  m_iter_max = 1000;
  m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);

  m_tikhonov_atol = grid->ctx()->config()->get_number("inverse.tikhonov.atol");
  m_tikhonov_rtol = grid->ctx()->config()->get_number("inverse.tikhonov.rtol");
  m_tikhonov_ptol = grid->ctx()->config()->get_number("inverse.tikhonov.ptol");

  m_log = d0.grid()->ctx()->log();
}

IP_SSATaucTikhonovGNSolver::~IP_SSATaucTikhonovGNSolver() {
  // empty
}

TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::init() {
  return m_ssaforward.linearize_at(m_d0);
}

void IP_SSATaucTikhonovGNSolver::apply_GN(IceModelVec2S &x, IceModelVec2S &y) {
  this->apply_GN(x.vec(), y.vec());
}

//! @note Invoked as the multiplication callback of the PETSc shell matrix m_mat_GN
//! (see MatrixMultiplyCallback in the constructor).
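//!
//! Applies the Gauss-Newton operator y = (T^* W_S T + alpha*W_D) x, where T is the
//! linearization of the forward map, T^* its transpose, and W_S, W_D denote the weighted
//! inner products applied by m_stateFunctional and m_designFunctional (their
//! interior_product() methods).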
void IP_SSATaucTikhonovGNSolver::apply_GN(Vec x, Vec y) {
  StateVec &tmp_gS = m_tmp_S1Global;
  StateVec &Tx = m_tmp_S1Local;
  DesignVec &tmp_gD = m_tmp_D1Global;
  DesignVec &GNx = m_tmp_D2Global;

  // FIXME: Needless copies for now.
  m_x.copy_from_vec(x);

  m_ssaforward.apply_linearization(m_x,Tx);
  Tx.update_ghosts();

  m_stateFunctional.interior_product(Tx,tmp_gS);

  m_ssaforward.apply_linearization_transpose(tmp_gS,GNx);

  m_designFunctional.interior_product(m_x,tmp_gD);
  GNx.add(m_alpha,tmp_gD);

  PetscErrorCode ierr = VecCopy(GNx.vec(), y); PISM_CHK(ierr, "VecCopy");
}

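//! Assembles the right-hand side of the Gauss-Newton step equation:
//! rhs = -( T^* W_S (F(d) - u_obs) + alpha*W_D (d - d0) ),
//! using the cached differences m_u_diff = F(d) - u_obs and m_d_diff = d - d0.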
void IP_SSATaucTikhonovGNSolver::assemble_GN_rhs(DesignVec &rhs) {

  rhs.set(0);

  m_stateFunctional.interior_product(m_u_diff,m_tmp_S1Global);
  m_ssaforward.apply_linearization_transpose(m_tmp_S1Global,rhs);

  m_designFunctional.interior_product(m_d_diff,m_tmp_D1Global);
  rhs.add(m_alpha,m_tmp_D1Global);

  rhs.scale(-1);
}

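//! Solves the Gauss-Newton step equation (T^* W_S T + alpha*W_D) h = rhs with the
//! unpreconditioned CG solver configured in the constructor (option prefix "inv_gn_"),
//! then copies the result into the ghosted step vector m_h.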
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::solve_linearized() {
  PetscErrorCode ierr;

  this->assemble_GN_rhs(m_GN_rhs);

  ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
  PISM_CHK(ierr, "KSPSetOperators");

  ierr = KSPSolve(m_ksp,m_GN_rhs.vec(),m_hGlobal.vec());
  PISM_CHK(ierr, "KSPSolve");

  KSPConvergedReason ksp_reason;
  ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
  PISM_CHK(ierr, "KSPGetConvergedReason");

  m_h.copy_from(m_hGlobal);

  return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
}

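//! Evaluates the linearized (Gauss-Newton) approximation of the Tikhonov functional at a
//! candidate step h: J_GN(h) = alpha*J_design(d_diff + h) + J_state(T h + u_diff).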
void IP_SSATaucTikhonovGNSolver::evaluateGNFunctional(DesignVec &h, double *value) {

  m_ssaforward.apply_linearization(h,m_tmp_S1Local);
  m_tmp_S1Local.update_ghosts();
  m_tmp_S1Local.add(1,m_u_diff);

  double sValue;
  m_stateFunctional.valueAt(m_tmp_S1Local,&sValue);

  m_tmp_D1Local.copy_from(m_d_diff);
  m_tmp_D1Local.add(1,h);

  double dValue;
  m_designFunctional.valueAt(m_tmp_D1Local,&dValue);

  *value = m_alpha*dValue + sValue;
}

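//! Checks convergence of the outer Tikhonov iteration: stop with TIKHONOV_ATOL when the norm
//! of the combined gradient drops below inverse.tikhonov.atol, with TIKHONOV_RTOL when it
//! drops below inverse.tikhonov.rtol times the larger of the (weighted) design and state
//! gradient norms, or report max_iter once the iteration count exceeds -inv_gn_iter_max.
//! With -tikhonov_adaptive, the relative discrepancy |sqrt(J_state)/target_misfit - 1| must
//! also be below inverse.tikhonov.ptol before the iteration is allowed to stop.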
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::check_convergence() {

  double designNorm, stateNorm, sumNorm;
  double dWeight, sWeight;
  dWeight = m_alpha;
  sWeight = 1;

  designNorm = m_grad_design.norm(NORM_2);
  stateNorm  = m_grad_state.norm(NORM_2);

  designNorm *= dWeight;
  stateNorm  *= sWeight;

  sumNorm = m_gradient.norm(NORM_2);

  m_log->message(2,
                 "----------------------------------------------------------\n");
  m_log->message(2,
                 "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g\n",
                 m_iter, sqrt(m_val_state)*m_vel_scale, m_value*m_vel_scale*m_vel_scale);
  if (m_tikhonov_adaptive) {
    m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
  }
  double relsum = (sumNorm/std::max(designNorm,stateNorm));
  m_log->message(2,
                 "design norm %g state norm %g sum %g; relative difference %g\n",
                 designNorm, stateNorm, sumNorm, relsum);

  // If we have an adaptive Tikhonov parameter, check if we have met
  // the discrepancy constraint first.
  if (m_tikhonov_adaptive) {
    double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
    if (disc_ratio > m_tikhonov_ptol) {
      return GenericTerminationReason::keep_iterating();
    }
  }

  if (sumNorm < m_tikhonov_atol) {
    return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_ATOL"));
  }

  if (sumNorm < m_tikhonov_rtol*std::max(designNorm,stateNorm)) {
    return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_RTOL"));
  }

  if (m_iter > m_iter_max) {
    return GenericTerminationReason::max_iter();
  } else {
    return GenericTerminationReason::keep_iterating();
  }
}

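//! Re-linearizes the forward model at the current design iterate and computes the value and
//! reduced gradient of the full Tikhonov functional: the state gradient is pulled back through
//! the transpose of the linearized forward map, so that
//! m_gradient = alpha * grad J_design(d - d0) + T^* grad J_state(F(d) - u_obs) and
//! m_value = alpha*J_design + J_state.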
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::evaluate_objective_and_gradient() {

  TerminationReason::Ptr reason = m_ssaforward.linearize_at(*m_d);
  if (reason->failed()) {
    return reason;
  }

  m_d_diff.copy_from(*m_d);
  m_d_diff.add(-1,m_d0);

  m_u_diff.copy_from(*m_ssaforward.solution());
  m_u_diff.add(-1,m_u_obs);

  m_designFunctional.gradientAt(m_d_diff,m_grad_design);

  // The following computes the reduced gradient.
  StateVec &adjointRHS = m_tmp_S1Global;
  m_stateFunctional.gradientAt(m_u_diff,adjointRHS);
  m_ssaforward.apply_linearization_transpose(adjointRHS,m_grad_state);

  m_gradient.copy_from(m_grad_design);
  m_gradient.scale(m_alpha);
  m_gradient.add(1,m_grad_state);

  double valDesign, valState;
  m_designFunctional.valueAt(m_d_diff,&valDesign);
  m_stateFunctional.valueAt(m_u_diff,&valState);

  m_val_design = valDesign;
  m_val_state = valState;

  m_value = valDesign * m_alpha + valState;

  return reason;
}

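//! Backtracking line search along the Gauss-Newton step m_h: starting with the full step,
//! the step length t is halved until the sufficient-decrease condition
//! J(d + t*h) <= J(d) + 1e-3 * t * <gradient, h> holds, giving up once t < 1e-20.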
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::linesearch() {
  PetscErrorCode ierr;

  TerminationReason::Ptr step_reason;

  double old_value = m_val_design * m_alpha + m_val_state;

  double descent_derivative;

  m_tmp_D1Global.copy_from(m_h);

  ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
  PISM_CHK(ierr, "VecDot");

  if (descent_derivative >= 0) {
    printf("descent derivative: %g\n",descent_derivative);
    return TerminationReason::Ptr(new GenericTerminationReason(-1, "Not a descent direction"));
  }

  double alpha = 1;
  m_tmp_D1Local.copy_from(*m_d);
  while (true) {
    m_d->add(alpha,m_h); // Replace with line search.
    step_reason = this->evaluate_objective_and_gradient();
    if (step_reason->succeeded()) {
      if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
        break;
      }
    }
    else {
      printf("forward solve failed in linesearch. Shrinking.\n");
    }
    alpha *= .5;
    if (alpha < 1e-20) {
      printf("alpha = %g; derivative = %g\n",alpha,descent_derivative);
      return TerminationReason::Ptr(new GenericTerminationReason(-1, "Too many step shrinks."));
    }
    m_d->copy_from(m_tmp_D1Local);
  }

  return GenericTerminationReason::success();
}

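//! Outer Tikhonov/Gauss-Newton iteration. Starting from d = d0, each pass checks convergence,
//! (with -tikhonov_adaptive) applies the pending update to log(alpha), solves the linearized
//! Gauss-Newton system for the step h, runs the backtracking line search, and (again with
//! -tikhonov_adaptive) computes the next change in log(alpha) from the discrepancy principle.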
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::solve() {

  if (m_target_misfit == 0) {
    throw RuntimeError::formatted(PISM_ERROR_LOCATION, "The target misfit must be set prior to calling"
                                  " IP_SSATaucTikhonovGNSolver::solve().");
  }

  m_iter = 0;
  m_d->copy_from(m_d0);

  double dlogalpha = 0;

  TerminationReason::Ptr step_reason, reason;

  step_reason = this->evaluate_objective_and_gradient();
  if (step_reason->failed()) {
    reason.reset(new GenericTerminationReason(-1,"Forward solve"));
    reason->set_root_cause(step_reason);
    return reason;
  }

  while (true) {

    reason = this->check_convergence();
    if (reason->done()) {
      return reason;
    }

    if (m_tikhonov_adaptive) {
      m_logalpha += dlogalpha;
      m_alpha = exp(m_logalpha);
    }

    step_reason = this->solve_linearized();
    if (step_reason->failed()) {
      reason.reset(new GenericTerminationReason(-1,"Gauss Newton solve"));
      reason->set_root_cause(step_reason);
      return reason;
    }

    step_reason = this->linesearch();
    if (step_reason->failed()) {
      reason.reset(new GenericTerminationReason(-1,"Linesearch"));
      reason->set_root_cause(step_reason);
      return reason;
    }

    if (m_tikhonov_adaptive) {
      step_reason = this->compute_dlogalpha(&dlogalpha);
      if (step_reason->failed()) {
        reason.reset(new GenericTerminationReason(-1,"Tikhonov penalty update"));
        reason->set_root_cause(step_reason);
        return reason;
      }
    }

    m_iter++;
  }

  return reason;
}

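//! Computes the change dlogalpha in log(alpha) used by the adaptive (discrepancy principle)
//! variant: a Newton step on the squared linearized discrepancy disc^2(alpha), chosen to
//! drive it toward m_target_misfit^2.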
TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::compute_dlogalpha(double *dlogalpha) {

  PetscErrorCode ierr;

  // Compute the right-hand side for computing dh/dalpha.
  m_d_diff_lin.copy_from(m_d_diff);
  m_d_diff_lin.add(1,m_h);
  m_designFunctional.interior_product(m_d_diff_lin,m_dalpha_rhs);
  m_dalpha_rhs.scale(-1);

  // Solve the linear equation for dh/dalpha.
  ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
  PISM_CHK(ierr, "KSPSetOperators");

  ierr = KSPSolve(m_ksp,m_dalpha_rhs.vec(),m_dh_dalphaGlobal.vec());
  PISM_CHK(ierr, "KSPSolve");

  m_dh_dalpha.copy_from(m_dh_dalphaGlobal);

  KSPConvergedReason ksp_reason;
  ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
  PISM_CHK(ierr, "KSPGetConvergedReason");

  if (ksp_reason < 0) {
    return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
  }

  // S1Local contains T(h) + F(d) - u_obs, i.e. the linearized misfit field.
  m_ssaforward.apply_linearization(m_h,m_tmp_S1Local);
  m_tmp_S1Local.update_ghosts();
  m_tmp_S1Local.add(1,m_u_diff);

  // Compute the linearized discrepancy.
  double disc_sq;
  m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S1Local,&disc_sq);

  // There are a number of equivalent ways to compute the derivative of the
  // linearized discrepancy with respect to alpha, some of which are cheaper
  // than others to compute. This equivalency relies, however, on having an
  // exact solution of the Gauss-Newton step. Since we only solve it with
  // a soft tolerance, we lose the equivalency. We attempt a cheap computation,
  // and then do a sanity check (namely that the derivative is positive).
  // If this fails, we compute it a more expensive way that inherently yields
  // a positive number.

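  // Cheap formula: d(disc^2)/dalpha = -2*alpha * <dh/dalpha, d_diff + h>_D.
  // Fallback (nonnegative by construction):
  //   d(disc^2)/dalpha = 2*alpha * ( |T dh/dalpha|_S^2 + alpha*|dh/dalpha|_D^2 ).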
  double ddisc_sq_dalpha;
  m_designFunctional.dot(m_dh_dalpha,m_d_diff_lin,&ddisc_sq_dalpha);
  ddisc_sq_dalpha *= -2*m_alpha;

  if (ddisc_sq_dalpha <= 0) {
    // Try harder.

    m_log->message(3,
                   "Adaptive Tikhonov sanity check failed (d(disc^2)/dalpha = %g <= 0)."
                   " Tighten inv_gn_ksp_rtol?\n",
                   ddisc_sq_dalpha);

    // S2Local contains T(dh/dalpha).
    m_ssaforward.apply_linearization(m_dh_dalpha,m_tmp_S2Local);
    m_tmp_S2Local.update_ghosts();

    double ddisc_sq_dalpha_a;
    m_stateFunctional.dot(m_tmp_S2Local,m_tmp_S2Local,&ddisc_sq_dalpha_a);
    double ddisc_sq_dalpha_b;
    m_designFunctional.dot(m_dh_dalpha,m_dh_dalpha,&ddisc_sq_dalpha_b);
    ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a+m_alpha*ddisc_sq_dalpha_b);

    m_log->message(3,
                   "Adaptive Tikhonov sanity check recovery attempt: d(disc^2)/dalpha = %g.\n",
                   ddisc_sq_dalpha);

    // This is yet another alternative formula.
    // m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S2Local,&ddisc_sq_dalpha);
    // ddisc_sq_dalpha *= 2;
  }

  // Newton's method formula.
  *dlogalpha = (m_target_misfit*m_target_misfit-disc_sq)/(ddisc_sq_dalpha*m_alpha);

  // It's easy to take steps that are too big when we are far from the solution,
  // so we limit the step size.
  double stepmax = 3;
  if (fabs(*dlogalpha) > stepmax) {
    double sgn = *dlogalpha > 0 ? 1 : -1;
    *dlogalpha = stepmax*sgn;
  }

  if (*dlogalpha < 0) {
    *dlogalpha *= .5;
  }

  return GenericTerminationReason::success();
}

} // end of namespace inverse
} // end of namespace pism