URI:
       tIP_SSATaucTikhonovGNSolver.cc - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
  HTML git clone git://src.adamsgaard.dk/pism
   DIR Log
   DIR Files
   DIR Refs
   DIR LICENSE
       ---
       tIP_SSATaucTikhonovGNSolver.cc (16208B)
       ---
            1 // Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017, 2019  David Maxwell and Constantine Khroulev
            2 //
            3 // This file is part of PISM.
            4 //
            5 // PISM is free software; you can redistribute it and/or modify it under the
            6 // terms of the GNU General Public License as published by the Free Software
            7 // Foundation; either version 3 of the License, or (at your option) any later
            8 // version.
            9 //
           10 // PISM is distributed in the hope that it will be useful, but WITHOUT ANY
           11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
           12 // FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
           13 // details.
           14 //
           15 // You should have received a copy of the GNU General Public License
           16 // along with PISM; if not, write to the Free Software
           17 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
           18 
           19 #include "IP_SSATaucTikhonovGNSolver.hh"
           20 #include "pism/util/TerminationReason.hh"
           21 #include "pism/util/pism_options.hh"
           22 #include "pism/util/ConfigInterface.hh"
           23 #include "pism/util/IceGrid.hh"
           24 
           25 namespace pism {
           26 namespace inverse {
           27 
           28 IP_SSATaucTikhonovGNSolver::IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward,
           29                                                        DesignVec &d0, StateVec &u_obs, double eta,
           30                                                        IPInnerProductFunctional<DesignVec> &designFunctional,
           31                                                        IPInnerProductFunctional<StateVec> &stateFunctional)
           32   : m_design_stencil_width(d0.stencil_width()),
           33     m_state_stencil_width(u_obs.stencil_width()),
           34     m_ssaforward(ssaforward),
           35     m_x(d0.grid(), "x", WITH_GHOSTS, m_design_stencil_width),
           36     m_tmp_D1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
           37     m_tmp_D2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
           38     m_tmp_D1Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
           39     m_tmp_D2Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
           40     m_tmp_S1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
           41     m_tmp_S2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
           42     m_tmp_S1Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
           43     m_tmp_S2Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
           44     m_GN_rhs(d0.grid(), "GN_rhs", WITHOUT_GHOSTS, 0),
           45     m_d0(d0),
           46     m_dGlobal(d0.grid(), "d (sans ghosts)", WITHOUT_GHOSTS, 0),
           47     m_d_diff(d0.grid(), "d_diff", WITH_GHOSTS, m_design_stencil_width),
           48     m_d_diff_lin(d0.grid(), "d_diff linearized", WITH_GHOSTS, m_design_stencil_width),
           49     m_h(d0.grid(), "h", WITH_GHOSTS, m_design_stencil_width),
           50     m_hGlobal(d0.grid(), "h (sans ghosts)", WITHOUT_GHOSTS),
           51     m_dalpha_rhs(d0.grid(), "dalpha rhs", WITHOUT_GHOSTS),
           52     m_dh_dalpha(d0.grid(), "dh_dalpha", WITH_GHOSTS, m_design_stencil_width),
           53     m_dh_dalphaGlobal(d0.grid(), "dh_dalpha", WITHOUT_GHOSTS),
           54     m_grad_design(d0.grid(), "grad design", WITHOUT_GHOSTS),
           55     m_grad_state(d0.grid(), "grad design", WITHOUT_GHOSTS),
           56     m_gradient(d0.grid(), "grad design", WITHOUT_GHOSTS),
           57     m_u_obs(u_obs),
           58     m_u_diff(d0.grid(), "du", WITH_GHOSTS, m_state_stencil_width),
           59     m_eta(eta),
           60     m_designFunctional(designFunctional),
           61     m_stateFunctional(stateFunctional),
           62     m_target_misfit(0.0)
           63 {
           64   PetscErrorCode ierr;
           65   IceGrid::ConstPtr grid = m_d0.grid();
           66   m_comm = grid->com;
           67 
           68   m_d.reset(new DesignVec(grid, "d", WITH_GHOSTS, m_design_stencil_width));
           69 
           70   ierr = KSPCreate(grid->com, m_ksp.rawptr());
           71   PISM_CHK(ierr, "KSPCreate");
           72 
           73   ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
           74   PISM_CHK(ierr, "KSPSetOptionsPrefix");
           75 
           76   double ksp_rtol = 1e-5; // Soft tolerance
           77   ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
           78   PISM_CHK(ierr, "KSPSetTolerances");
           79 
           80   ierr = KSPSetType(m_ksp, KSPCG);
           81   PISM_CHK(ierr, "KSPSetType");
           82 
           83   PC pc;
           84   ierr = KSPGetPC(m_ksp, &pc);
           85   PISM_CHK(ierr, "KSPGetPC");
           86 
           87   ierr = PCSetType(pc, PCNONE);
           88   PISM_CHK(ierr, "PCSetType");
           89 
           90   ierr = KSPSetFromOptions(m_ksp);
           91   PISM_CHK(ierr, "KSPSetFromOptions");
           92 
           93   int nLocalNodes  = grid->xm()*grid->ym();
           94   int nGlobalNodes = grid->Mx()*grid->My();
           95   ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
           96                         nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
           97   PISM_CHK(ierr, "MatCreateShell");
           98 
           99   typedef MatrixMultiplyCallback<IP_SSATaucTikhonovGNSolver,
          100                                  &IP_SSATaucTikhonovGNSolver::apply_GN> multCallback;
          101   multCallback::connect(m_mat_GN);
          102 
          103   m_alpha = 1./m_eta;
          104   m_logalpha = log(m_alpha);
          105 
          106   m_tikhonov_adaptive = options::Bool("-tikhonov_adaptive", "Tikhonov adaptive");
          107 
          108   m_iter_max = 1000;
          109   m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);
          110 
          111   m_tikhonov_atol = grid->ctx()->config()->get_number("inverse.tikhonov.atol");
          112   m_tikhonov_rtol = grid->ctx()->config()->get_number("inverse.tikhonov.rtol");
          113   m_tikhonov_ptol = grid->ctx()->config()->get_number("inverse.tikhonov.ptol");
          114 
          115   m_log = d0.grid()->ctx()->log();
          116 }
          117 
          118 IP_SSATaucTikhonovGNSolver::~IP_SSATaucTikhonovGNSolver() {
          119   // empty
          120 }
          121 
          122 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::init() {
          123   return m_ssaforward.linearize_at(m_d0);
          124 }
          125 
          126 void IP_SSATaucTikhonovGNSolver::apply_GN(IceModelVec2S &x, IceModelVec2S &y) {
          127   this->apply_GN(x.vec(), y.vec());
          128 }
          129 
          130 //! @note This function has to return PetscErrorCode (it is used as a callback).
          131 void  IP_SSATaucTikhonovGNSolver::apply_GN(Vec x, Vec y) {
          132   StateVec  &tmp_gS = m_tmp_S1Global;
          133   StateVec  &Tx     = m_tmp_S1Local;
          134   DesignVec &tmp_gD = m_tmp_D1Global;
          135   DesignVec &GNx    = m_tmp_D2Global;
          136 
          137   // FIXME: Needless copies for now.
          138   m_x.copy_from_vec(x);
          139 
          140   m_ssaforward.apply_linearization(m_x,Tx);
          141   Tx.update_ghosts();
          142 
          143   m_stateFunctional.interior_product(Tx,tmp_gS);
          144 
          145   m_ssaforward.apply_linearization_transpose(tmp_gS,GNx);
          146 
          147   m_designFunctional.interior_product(m_x,tmp_gD);
          148   GNx.add(m_alpha,tmp_gD);
          149 
          150   PetscErrorCode ierr = VecCopy(GNx.vec(), y); PISM_CHK(ierr, "VecCopy");
          151 }
          152 
          153 void IP_SSATaucTikhonovGNSolver::assemble_GN_rhs(DesignVec &rhs) {
          154 
          155   rhs.set(0);
          156   
          157   m_stateFunctional.interior_product(m_u_diff,m_tmp_S1Global);
          158   m_ssaforward.apply_linearization_transpose(m_tmp_S1Global,rhs);
          159 
          160   m_designFunctional.interior_product(m_d_diff,m_tmp_D1Global);
          161   rhs.add(m_alpha,m_tmp_D1Global);
          162   
          163   rhs.scale(-1);
          164 }
          165 
          166 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::solve_linearized() {
          167   PetscErrorCode ierr;
          168 
          169   this->assemble_GN_rhs(m_GN_rhs);
          170 
          171   ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
          172   PISM_CHK(ierr, "KSPSetOperators");
          173 
          174   ierr = KSPSolve(m_ksp,m_GN_rhs.vec(),m_hGlobal.vec());
          175   PISM_CHK(ierr, "KSPSolve");
          176 
          177   KSPConvergedReason ksp_reason;
          178   ierr = KSPGetConvergedReason(m_ksp ,&ksp_reason);
          179   PISM_CHK(ierr, "KSPGetConvergedReason");
          180   
          181   m_h.copy_from(m_hGlobal);
          182 
          183   return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
          184 }
          185 
          186 void IP_SSATaucTikhonovGNSolver::evaluateGNFunctional(DesignVec &h, double *value) {
          187   
          188   m_ssaforward.apply_linearization(h,m_tmp_S1Local);
          189   m_tmp_S1Local.update_ghosts();
          190   m_tmp_S1Local.add(1,m_u_diff);
          191   
          192   double sValue;
          193   m_stateFunctional.valueAt(m_tmp_S1Local,&sValue);
          194   
          195   m_tmp_D1Local.copy_from(m_d_diff);
          196   m_tmp_D1Local.add(1,h);
          197   
          198   double dValue;
          199   m_designFunctional.valueAt(m_tmp_D1Local,&dValue);
          200   
          201   *value = m_alpha*dValue + sValue;
          202 }
          203 
          204 
          205 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::check_convergence() {
          206 
          207   double designNorm, stateNorm, sumNorm;
          208   double dWeight, sWeight;
          209   dWeight = m_alpha;
          210   sWeight = 1;
          211 
          212   designNorm = m_grad_design.norm(NORM_2);
          213   stateNorm  = m_grad_state.norm(NORM_2);
          214 
          215   designNorm *= dWeight;
          216   stateNorm  *= sWeight;
          217 
          218   sumNorm = m_gradient.norm(NORM_2);
          219 
          220   m_log->message(2,
          221              "----------------------------------------------------------\n");
          222   m_log->message(2,
          223              "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g \n",
          224              m_iter, sqrt(m_val_state)*m_vel_scale, m_value*m_vel_scale*m_vel_scale);
          225   if (m_tikhonov_adaptive) {
          226     m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
          227   }
          228   double relsum = (sumNorm/std::max(designNorm,stateNorm));
          229   m_log->message(2,
          230              "design norm %g stateNorm %g sum %g; relative difference %g\n",
          231              designNorm, stateNorm, sumNorm, relsum);
          232 
          233   // If we have an adaptive tikhonov parameter, check if we have met
          234   // this constraint first.
          235   if (m_tikhonov_adaptive) {
          236     double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
          237     if (disc_ratio > m_tikhonov_ptol) {
          238       return GenericTerminationReason::keep_iterating();
          239     }
          240   }
          241   
          242   if (sumNorm < m_tikhonov_atol) {
          243     return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_ATOL"));
          244   }
          245 
          246   if (sumNorm < m_tikhonov_rtol*std::max(designNorm,stateNorm)) {
          247     return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_RTOL"));
          248   }
          249 
          250   if (m_iter>m_iter_max) {
          251     return GenericTerminationReason::max_iter();
          252   } else {
          253     return GenericTerminationReason::keep_iterating();
          254   }
          255 }
          256 
          257 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::evaluate_objective_and_gradient() {
          258 
          259   TerminationReason::Ptr reason = m_ssaforward.linearize_at(*m_d);
          260   if (reason->failed()) {
          261     return reason;
          262   }
          263 
          264   m_d_diff.copy_from(*m_d);
          265   m_d_diff.add(-1,m_d0);
          266 
          267   m_u_diff.copy_from(*m_ssaforward.solution());
          268   m_u_diff.add(-1,m_u_obs);
          269 
          270   m_designFunctional.gradientAt(m_d_diff,m_grad_design);
          271 
          272   // The following computes the reduced gradient.
          273   StateVec &adjointRHS = m_tmp_S1Global;
          274   m_stateFunctional.gradientAt(m_u_diff,adjointRHS);  
          275   m_ssaforward.apply_linearization_transpose(adjointRHS,m_grad_state);
          276 
          277   m_gradient.copy_from(m_grad_design);
          278   m_gradient.scale(m_alpha);    
          279   m_gradient.add(1,m_grad_state);
          280 
          281   double valDesign, valState;
          282   m_designFunctional.valueAt(m_d_diff,&valDesign);
          283   m_stateFunctional.valueAt(m_u_diff,&valState);
          284 
          285   m_val_design = valDesign;
          286   m_val_state = valState;
          287   
          288   m_value = valDesign * m_alpha + valState;
          289 
          290   return reason;
          291 }
          292 
          293 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::linesearch() {
          294   PetscErrorCode ierr;
          295 
          296   TerminationReason::Ptr step_reason;
          297 
          298   double old_value = m_val_design * m_alpha + m_val_state;
          299 
          300   double descent_derivative;
          301 
          302   m_tmp_D1Global.copy_from(m_h);
          303 
          304   ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
          305   PISM_CHK(ierr, "VecDot");
          306 
          307   if (descent_derivative >=0) {
          308     printf("descent derivative: %g\n",descent_derivative);
          309     return TerminationReason::Ptr(new GenericTerminationReason(-1, "Not descent direction"));
          310   }
          311 
          312   double alpha = 1;
          313   m_tmp_D1Local.copy_from(*m_d);
          314   while(true) {
          315     m_d->add(alpha,m_h);  // Replace with line search.
          316     step_reason = this->evaluate_objective_and_gradient();
          317     if (step_reason->succeeded()) {
          318       if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
          319         break;
          320       }
          321     }
          322     else {
          323       printf("forward solve failed in linsearch.  Shrinking.\n");
          324     }
          325     alpha *=.5;
          326     if (alpha<1e-20) {
          327       printf("alpha= %g; derivative = %g\n",alpha,descent_derivative);
          328       return TerminationReason::Ptr(new GenericTerminationReason(-1, "Too many step shrinks."));
          329     }
          330     m_d->copy_from(m_tmp_D1Local);
          331   }
          332   
          333   return GenericTerminationReason::success();
          334 }
          335 
          336 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::solve() {
          337 
          338   if (m_target_misfit == 0) {
          339     throw RuntimeError::formatted(PISM_ERROR_LOCATION, "Call set target misfit prior to calling"
          340                                   " IP_SSATaucTikhonovGNSolver::solve.");
          341   }
          342 
          343   m_iter = 0;
          344   m_d->copy_from(m_d0);
          345 
          346   double dlogalpha = 0;
          347 
          348   TerminationReason::Ptr step_reason, reason;
          349 
          350   step_reason = this->evaluate_objective_and_gradient();
          351   if (step_reason->failed()) {
          352     reason.reset(new GenericTerminationReason(-1,"Forward solve"));
          353     reason->set_root_cause(step_reason);
          354     return reason;
          355   }
          356 
          357   while(true) {
          358 
          359     reason = this->check_convergence();
          360     if (reason->done()) {
          361       return reason;
          362     }
          363 
          364     if (m_tikhonov_adaptive) {
          365       m_logalpha += dlogalpha;
          366       m_alpha = exp(m_logalpha);
          367     }
          368 
          369     step_reason = this->solve_linearized();
          370     if (step_reason->failed()) {
          371       reason.reset(new GenericTerminationReason(-1,"Gauss Newton solve"));
          372       reason->set_root_cause(step_reason);
          373       return reason;
          374     }
          375 
          376     step_reason = this->linesearch();
          377     if (step_reason->failed()) {
          378       TerminationReason::Ptr cause = reason;
          379       reason.reset(new GenericTerminationReason(-1,"Linesearch"));
          380       reason->set_root_cause(step_reason);
          381       return reason;
          382     }
          383 
          384     if (m_tikhonov_adaptive) {
          385       step_reason = this->compute_dlogalpha(&dlogalpha);
          386       if (step_reason->failed()) {
          387         TerminationReason::Ptr cause = reason;
          388         reason.reset(new GenericTerminationReason(-1,"Tikhonov penalty update"));
          389         reason->set_root_cause(step_reason);
          390         return reason;
          391       }
          392     }
          393 
          394     m_iter++;
          395   }
          396 
          397   return reason;
          398 }
          399 
          400 TerminationReason::Ptr IP_SSATaucTikhonovGNSolver::compute_dlogalpha(double *dlogalpha) {
          401 
          402   PetscErrorCode ierr;
          403 
          404   // Compute the right-hand side for computing dh/dalpha.
          405   m_d_diff_lin.copy_from(m_d_diff);
          406   m_d_diff_lin.add(1,m_h);  
          407   m_designFunctional.interior_product(m_d_diff_lin,m_dalpha_rhs);
          408   m_dalpha_rhs.scale(-1);
          409 
          410   // Solve linear equation for dh/dalpha. 
          411   ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
          412   PISM_CHK(ierr, "KSPSetOperators");
          413 
          414   ierr = KSPSolve(m_ksp,m_dalpha_rhs.vec(),m_dh_dalphaGlobal.vec());
          415   PISM_CHK(ierr, "KSPSolve");
          416 
          417   m_dh_dalpha.copy_from(m_dh_dalphaGlobal);
          418 
          419   KSPConvergedReason ksp_reason;
          420   ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
          421   PISM_CHK(ierr, "KSPGetConvergedReason");
          422 
          423   if (ksp_reason<0) {
          424     return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
          425   }
          426 
          427   // S1Local contains T(h) + F(x) - u_obs, i.e. the linearized misfit field.
          428   m_ssaforward.apply_linearization(m_h,m_tmp_S1Local);
          429   m_tmp_S1Local.update_ghosts();
          430   m_tmp_S1Local.add(1,m_u_diff);
          431 
          432   // Compute linearized discrepancy.
          433   double disc_sq;
          434   m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S1Local,&disc_sq);
          435 
          436   // There are a number of equivalent ways to compute the derivative of the 
          437   // linearized discrepancy with respect to alpha, some of which are cheaper
          438   // than others to compute.  This equivalency relies, however, on having an 
          439   // exact solution in the Gauss-Newton step.  Since we only solve this with 
          440   // a soft tolerance, we lose equivalency.  We attempt a cheap computation,
          441   // and then do a sanity check (namely that the derivative is positive).
          442   // If this fails, we compute by a harder way that inherently yields a 
          443   // positive number.
          444 
          445   double ddisc_sq_dalpha;
          446   m_designFunctional.dot(m_dh_dalpha,m_d_diff_lin,&ddisc_sq_dalpha);
          447   ddisc_sq_dalpha *= -2*m_alpha;
          448 
          449   if (ddisc_sq_dalpha <= 0) {
          450     // Try harder.
          451     
          452     m_log->message(3,
          453                "Adaptive Tikhonov sanity check failed (dh/dalpha= %g <= 0)."
          454                " Tighten inv_gn_ksp_rtol?\n",
          455                ddisc_sq_dalpha);
          456     
          457     // S2Local contains T(dh/dalpha)
          458     m_ssaforward.apply_linearization(m_dh_dalpha,m_tmp_S2Local);
          459     m_tmp_S2Local.update_ghosts();
          460 
          461     double ddisc_sq_dalpha_a;
          462     m_stateFunctional.dot(m_tmp_S2Local,m_tmp_S2Local,&ddisc_sq_dalpha_a);
          463     double ddisc_sq_dalpha_b;
          464     m_designFunctional.dot(m_dh_dalpha,m_dh_dalpha,&ddisc_sq_dalpha_b);
          465     ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a+m_alpha*ddisc_sq_dalpha_b);
          466 
          467     m_log->message(3,
          468                "Adaptive Tikhonov sanity check recovery attempt: dh/dalpha= %g. \n",
          469                ddisc_sq_dalpha);
          470 
          471     // This is yet another alternative formula.
          472     // m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S2Local,&ddisc_sq_dalpha);
          473     // ddisc_sq_dalpha *= 2;
          474   }
          475 
          476   // Newton's method formula.
          477   *dlogalpha = (m_target_misfit*m_target_misfit-disc_sq)/(ddisc_sq_dalpha*m_alpha);
          478 
          479   // It's easy to take steps that are too big when we are far from the solution.
          480   // So we limit the step size.
          481   double stepmax = 3;
          482   if (fabs(*dlogalpha)> stepmax) {
          483     double sgn = *dlogalpha > 0 ? 1 : -1;
          484     *dlogalpha = stepmax*sgn;
          485   }
          486   
          487   if (*dlogalpha<0) {
          488     *dlogalpha*=.5;
          489   }
          490 
          491   return GenericTerminationReason::success();
          492 }
          493 
          494 } // end of namespace inverse
          495 } // end of namespace pism