URI:
       tSNESProblem.hh - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
  HTML git clone git://src.adamsgaard.dk/pism
   DIR Log
   DIR Files
   DIR Refs
   DIR LICENSE
       ---
       tSNESProblem.hh (5414B)
       ---
            1 // Copyright (C) 2011, 2012, 2014, 2015, 2016, 2017 David Maxwell
            2 //
            3 // This file is part of PISM.
            4 //
            5 // PISM is free software; you can redistribute it and/or modify it under the
            6 // terms of the GNU General Public License as published by the Free Software
            7 // Foundation; either version 3 of the License, or (at your option) any later
            8 // version.
            9 //
           10 // PISM is distributed in the hope that it will be useful, but WITHOUT ANY
           11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
           12 // FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
           13 // details.
           14 //
           15 // You should have received a copy of the GNU General Public License
           16 // along with PISM; if not, write to the Free Software
           17 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
           18 
           19 
           20 #ifndef _SNESPROBLEM_H_
           21 #define _SNESPROBLEM_H_
           22 
           23 #include "pism/util/IceGrid.hh" // inline implementation in the header uses IceGrid
           24 #include "pism/util/Vector2.hh" // to get Vector2
           25 #include "pism/util/petscwrappers/SNES.hh"
           26 #include "pism/util/Logger.hh"
           27 
           28 namespace pism {
           29 
           30 template<int DOF, class U> class SNESProblem {
           31 public:
           32   SNESProblem(IceGrid::ConstPtr g);
           33 
           34   virtual ~SNESProblem();
           35 
           36   virtual void solve();
           37 
           38   virtual const std::string& name();
           39 
           40   virtual Vec solution()
           41   {
           42     return m_X;
           43   }
           44 
           45 protected:
           46 
           47   virtual void compute_local_function(DMDALocalInfo *info, const U **xg, U **yg) = 0;
           48   virtual void compute_local_jacobian(DMDALocalInfo *info, const U **x,  Mat B) = 0;
           49 
           50   IceGrid::ConstPtr m_grid;
           51 
           52   petsc::Vec m_X;
           53   petsc::SNES m_snes;
           54   petsc::DM::Ptr m_DA;
           55 
           56 private:
           57 
           58   struct CallbackData {
           59     DM da;
           60     SNESProblem<DOF,U> *solver;
           61   };
           62 
           63   CallbackData m_callbackData;
           64 
           65   static PetscErrorCode function_callback(DMDALocalInfo *info, const U **x, U **f,
           66                                           CallbackData *);
           67   static PetscErrorCode jacobian_callback(DMDALocalInfo *info, const U **x, Mat B,
           68                                           CallbackData *);
           69 };
           70 
           71 typedef SNESProblem<1,double> SNESScalarProblem;
           72 typedef SNESProblem<2,Vector2> SNESVectorProblem;
           73 
           74 template<int DOF, class U>
           75 PetscErrorCode SNESProblem<DOF,U>::function_callback(DMDALocalInfo *info,
           76                                                      const U **x, U **f,
           77                                                      SNESProblem<DOF,U>::CallbackData *cb) {
           78   try {
           79     cb->solver->compute_local_function(info,x,f);
           80   } catch (...) {
           81     MPI_Comm com = MPI_COMM_SELF;
           82     PetscErrorCode ierr = PetscObjectGetComm((PetscObject)cb->da, &com); CHKERRQ(ierr);
           83     handle_fatal_errors(com);
           84     SETERRQ(com, 1, "A PISM callback failed");
           85   }
           86   return 0;
           87 }
           88 
           89 template<int DOF, class U>
           90 PetscErrorCode SNESProblem<DOF,U>::jacobian_callback(DMDALocalInfo *info,
           91                                                      const U **x, Mat J,
           92                                                      SNESProblem<DOF,U>::CallbackData *cb) {
           93   try {
           94     cb->solver->compute_local_jacobian(info, x, J);
           95   } catch (...) {
           96     MPI_Comm com = MPI_COMM_SELF;
           97     PetscErrorCode ierr = PetscObjectGetComm((PetscObject)cb->da, &com); CHKERRQ(ierr);
           98     handle_fatal_errors(com);
           99     SETERRQ(com, 1, "A PISM callback failed");
          100   }
          101   return 0;
          102 }
          103 
          104 template<int DOF, class U>
          105 SNESProblem<DOF, U>::SNESProblem(IceGrid::ConstPtr g)
          106   : m_grid(g) {
          107 
          108   PetscErrorCode ierr;
          109 
          110   int stencil_width=1;
          111   m_DA = m_grid->get_dm(DOF, stencil_width);
          112 
          113   ierr = DMCreateGlobalVector(*m_DA, m_X.rawptr());
          114   PISM_CHK(ierr, "DMCreateGlobalVector");
          115 
          116   ierr = SNESCreate(m_grid->com, m_snes.rawptr());
          117   PISM_CHK(ierr, "SNESCreate");
          118 
          119   // Set the SNES callbacks to call into our compute_local_function and compute_local_jacobian
          120   m_callbackData.da = *m_DA;
          121   m_callbackData.solver = this;
          122 
          123   ierr = DMDASNESSetFunctionLocal(*m_DA, INSERT_VALUES,
          124                                   (DMDASNESFunction)SNESProblem<DOF, U>::function_callback,
          125                                   &m_callbackData);
          126   PISM_CHK(ierr, "DMDASNESSetFunctionLocal");
          127 
          128   ierr = DMDASNESSetJacobianLocal(*m_DA, (DMDASNESJacobian)SNESProblem<DOF, U>::jacobian_callback,
          129                                   &m_callbackData);
          130   PISM_CHK(ierr, "DMDASNESSetJacobianLocal");
          131 
          132   ierr = DMSetMatType(*m_DA, "baij");
          133   PISM_CHK(ierr, "DMSetMatType");
          134 
          135   ierr = DMSetApplicationContext(*m_DA, &m_callbackData);
          136   PISM_CHK(ierr, "DMSetApplicationContext");
          137 
          138   ierr = SNESSetDM(m_snes, *m_DA);
          139   PISM_CHK(ierr, "SNESSetDM");
          140 
          141   ierr = SNESSetFromOptions(m_snes);
          142   PISM_CHK(ierr, "SNESSetFromOptions");
          143 }
          144 
          145 template<int DOF, class U>
          146 SNESProblem<DOF,U>::~SNESProblem() {
          147   // empty
          148 }
          149 
          150 template<int DOF, class U>
          151 const std::string& SNESProblem<DOF,U>::name() {
          152   return "UnnamedProblem";
          153 }
          154 
          155 template<int DOF, class U>
          156 void SNESProblem<DOF,U>::solve() {
          157   PetscErrorCode ierr;
          158 
          159   // Solve:
          160   ierr = SNESSolve(m_snes,NULL,m_X); PISM_CHK(ierr, "SNESSolve");
          161 
          162   // See if it worked.
          163   SNESConvergedReason reason;
          164   ierr = SNESGetConvergedReason(m_snes, &reason); PISM_CHK(ierr, "SNESGetConvergedReason");
          165   if (reason < 0) {
          166     throw RuntimeError::formatted(PISM_ERROR_LOCATION, "SNESProblem %s solve failed to converge (SNES reason %s)",
          167                                   name().c_str(), SNESConvergedReasons[reason]);
          168   }
          169 
          170   m_grid->ctx()->log()->message(1, "SNESProblem %s converged (SNES reason %s)\n",
          171                                name().c_str(), SNESConvergedReasons[reason]);
          172 }
          173 
          174 } // end of namespace pism
          175 
          176 #endif