URI:
       ssh-agent.c - plan9port - [fork] Plan 9 from user space
  HTML git clone git://src.adamsgaard.dk/plan9port
   DIR Log
   DIR Files
   DIR Refs
   DIR README
   DIR LICENSE
       ---
       ssh-agent.c (18364B)
       ---
            1 /*
            2  * Present factotum in ssh agent clothing.
            3  */
            4 #include <u.h>
            5 #include <libc.h>
            6 #include <mp.h>
            7 #include <libsec.h>
            8 #include <auth.h>
            9 #include <thread.h>
           10 #include <9pclient.h>
           11 
/* Stack size passed to proccreate for listenproc and each agentproc. */
enum
{
	STACK = 65536
};

/*
 * SSH agent protocol numbers.  Members without an explicit value
 * continue counting from the previous one, so e.g. SSH_AGENTC_LOCK
 * is 22 and SSH2_AGENTC_ADD_ID_CONSTRAINED is 25.
 */
enum		/* agent protocol packet types */
{
	/* SSH1 (RSA) agent messages, 0..9 */
	SSH_AGENTC_NONE = 0,
	SSH_AGENTC_REQUEST_RSA_IDENTITIES,
	SSH_AGENT_RSA_IDENTITIES_ANSWER,
	SSH_AGENTC_RSA_CHALLENGE,
	SSH_AGENT_RSA_RESPONSE,
	SSH_AGENT_FAILURE,
	SSH_AGENT_SUCCESS,
	SSH_AGENTC_ADD_RSA_IDENTITY,
	SSH_AGENTC_REMOVE_RSA_IDENTITY,
	SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES,

	/* SSH2 agent messages, 11..14 */
	SSH2_AGENTC_REQUEST_IDENTITIES = 11,
	SSH2_AGENT_IDENTITIES_ANSWER,
	SSH2_AGENTC_SIGN_REQUEST,
	SSH2_AGENT_SIGN_RESPONSE,

	/* key management, 17..21 */
	SSH2_AGENTC_ADD_IDENTITY = 17,
	SSH2_AGENTC_REMOVE_IDENTITY,
	SSH2_AGENTC_REMOVE_ALL_IDENTITIES,
	SSH2_AGENTC_ADD_SMARTCARD_KEY,
	SSH2_AGENTC_REMOVE_SMARTCARD_KEY,

	/* 22..26 */
	SSH_AGENTC_LOCK,
	SSH_AGENTC_UNLOCK,
	SSH_AGENTC_ADD_RSA_ID_CONSTRAINED,
	SSH2_AGENTC_ADD_ID_CONSTRAINED,
	SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED,

	/* constraint codes: a separate numbering, not message types */
	SSH_AGENT_CONSTRAIN_LIFETIME = 1,
	SSH_AGENT_CONSTRAIN_CONFIRM = 2,

	SSH2_AGENT_FAILURE = 30,

	SSH_COM_AGENT2_FAILURE = 102,
	/* NOTE(review): a flag value, presumably for sign-request flags — confirm in runmsg */
	SSH_AGENT_OLD_SIGNATURE = 0x01
};
           54 
/*
 * Aconn holds the state of one client connection: created by
 * listenproc, served and freed by agentproc.
 */
typedef struct Aconn Aconn;
struct Aconn
{
	uchar *data;	/* bytes read from the client, grown by agentproc */
	uint ndata;	/* number of valid bytes in data */
	int ctl;	/* control fd from listen; closed and set to -1 after accept */
	int fd;		/* data fd returned by accept */
	char dir[40];	/* connection directory filled in by listen */
};

/*
 * Msg is a cursor over a byte buffer.  The get* routines parse:
 * p advances toward ep, the end of the data.  The put* routines
 * build (through ensure): p is the write position and ep is the
 * end of the allocated capacity.
 */
typedef struct Msg Msg;
struct Msg
{
	uchar *bp;	/* start of the buffer */
	uchar *p;	/* current read or write position */
	uchar *ep;	/* end of data (parsing) or of capacity (building) */
	int bpalloc;	/* nonzero if bp was allocated here; mreset then frees it */
};
           73 
char adir[40];	/* announce directory: set by announce in threadmain, used by listenproc */
int afd;	/* fd returned by announce */
int chatty;	/* -D: print protocol traces on fd 2 */
char *factotum = "factotum";	/* factotum service name; replaced by the optional argument */

/* forward declarations */
void		agentproc(void *v);
void*	emalloc(int n);
void*	erealloc(void *v, int n);
void		listenproc(void *v);
int		runmsg(Aconn *a);
void		listkeystext(void);
           85 
           86 void
           87 usage(void)
           88 {
           89         fprint(2, "usage: 9 ssh-agent [-D] [factotum]\n");
           90         threadexitsall("usage");
           91 }
           92 
           93 int
           94 threadmaybackground(void)
           95 {
           96         return 1;
           97 }
           98 
           99 void
          100 threadmain(int argc, char **argv)
          101 {
          102         int fd, pid, export, dotextlist;
          103         char dir[100], *ns;
          104         char sock[200], addr[200];
          105         uvlong x;
          106 
          107         export = 0;
          108         dotextlist = 0;
          109         pid = getpid();
          110         fmtinstall('B', mpfmt);
          111         fmtinstall('H', encodefmt);
          112         fmtinstall('[', encodefmt);
          113 
          114         ARGBEGIN{
          115         case '9':
          116                 chatty9pclient++;
          117                 break;
          118         case 'D':
          119                 chatty++;
          120                 break;
          121         case 'e':
          122                 export = 1;
          123                 break;
          124         case 'l':
          125                 dotextlist = 1;
          126                 break;
          127         default:
          128                 usage();
          129         }ARGEND
          130 
          131         if(argc > 1)
          132                 usage();
          133         if(argc == 1)
          134                 factotum = argv[0];
          135 
          136         if(dotextlist)
          137                 listkeystext();
          138 
          139         ns = getns();
          140         snprint(sock, sizeof sock, "%s/ssh-agent.socket", ns);
          141         if(0){
          142                 x = ((uvlong)fastrand()<<32) | fastrand();
          143                 x ^= ((uvlong)fastrand()<<32) | fastrand();
          144                 snprint(dir, sizeof dir, "/tmp/ssh-%llux", x);
          145                 if((fd = create(dir, OREAD, DMDIR|0700)) < 0)
          146                         sysfatal("mkdir %s: %r", dir);
          147                 close(fd);
          148                 snprint(sock, sizeof sock, "%s/agent.%d", dir, pid);
          149         }
          150         snprint(addr, sizeof addr, "unix!%s", sock);
          151 
          152         if((afd = announce(addr, adir)) < 0)
          153                 sysfatal("announce %s: %r", addr);
          154 
          155         print("SSH_AUTH_SOCK=%s;\n", sock);
          156         if(export)
          157                 print("export SSH_AUTH_SOCK;\n");
          158         print("SSH_AGENT_PID=%d;\n", pid);
          159         if(export)
          160                 print("export SSH_AGENT_PID;\n");
          161         close(1);
          162         rfork(RFNOTEG);
          163         proccreate(listenproc, nil, STACK);
          164         threadexits(0);
          165 }
          166 
          167 void
          168 listenproc(void *v)
          169 {
          170         Aconn *a;
          171 
          172         USED(v);
          173         for(;;){
          174                 a = emalloc(sizeof *a);
          175                 a->ctl = listen(adir, a->dir);
          176                 if(a->ctl < 0)
          177                         sysfatal("listen: %r");
          178                 proccreate(agentproc, a, STACK);
          179         }
          180 }
          181 
          182 void
          183 agentproc(void *v)
          184 {
          185         Aconn *a;
          186         int n;
          187 
          188         a = v;
          189         a->fd = accept(a->ctl, a->dir);
          190         close(a->ctl);
          191         a->ctl = -1;
          192         for(;;){
          193                 a->data = erealloc(a->data, a->ndata+1024);
          194                 n = read(a->fd, a->data+a->ndata, 1024);
          195                 if(n <= 0)
          196                         break;
          197                 a->ndata += n;
          198                 while(runmsg(a))
          199                         ;
          200         }
          201         close(a->fd);
          202         free(a);
          203         threadexits(nil);
          204 }
          205 
          206 int
          207 get1(Msg *m)
          208 {
          209         if(m->p >= m->ep)
          210                 return 0;
          211         return *m->p++;
          212 }
          213 
          214 int
          215 get2(Msg *m)
          216 {
          217         uint x;
          218 
          219         if(m->p+2 > m->ep)
          220                 return 0;
          221         x = (m->p[0]<<8)|m->p[1];
          222         m->p += 2;
          223         return x;
          224 }
          225 
          226 int
          227 get4(Msg *m)
          228 {
          229         uint x;
          230         if(m->p+4 > m->ep)
          231                 return 0;
          232         x = (m->p[0]<<24)|(m->p[1]<<16)|(m->p[2]<<8)|m->p[3];
          233         m->p += 4;
          234         return x;
          235 }
          236 
          237 uchar*
          238 getn(Msg *m, uint n)
          239 {
          240         uchar *p;
          241 
          242         if(m->p+n > m->ep)
          243                 return nil;
          244         p = m->p;
          245         m->p += n;
          246         return p;
          247 }
          248 
          249 char*
          250 getstr(Msg *m)
          251 {
          252         uint n;
          253         uchar *p;
          254 
          255         n = get4(m);
          256         p = getn(m, n);
          257         if(p == nil)
          258                 return nil;
          259         p--;
          260         memmove(p, p+1, n);
          261         p[n] = 0;
          262         return (char*)p;
          263 }
          264 
          265 mpint*
          266 getmp(Msg *m)
          267 {
          268         int n;
          269         uchar *p;
          270 
          271         n = (get2(m)+7)/8;
          272         if((p=getn(m, n)) == nil)
          273                 return nil;
          274         return betomp(p, n, nil);
          275 }
          276 
          277 mpint*
          278 getmp2(Msg *m)
          279 {
          280         int n;
          281         uchar *p;
          282 
          283         n = get4(m);
          284         if((p = getn(m, n)) == nil)
          285                 return nil;
          286         return betomp(p, n, nil);
          287 }
          288 
          289 void
          290 newmsg(Msg *m)
          291 {
          292         memset(m, 0, sizeof *m);
          293 }
          294 
          295 void
          296 mreset(Msg *m)
          297 {
          298         if(m->bpalloc){
          299                 memset(m->bp, 0, m->ep-m->bp);
          300                 free(m->bp);
          301         }
          302         memset(m, 0, sizeof *m);
          303 }
          304 
          305 Msg*
          306 getm(Msg *m, Msg *mm)
          307 {
          308         uint n;
          309         uchar *p;
          310 
          311         n = get4(m);
          312         if((p = getn(m, n)) == nil)
          313                 return nil;
          314         mm->bp = p;
          315         mm->p = p;
          316         mm->ep = p+n;
          317         mm->bpalloc = 0;
          318         return mm;
          319 }
          320 
          321 uchar*
          322 ensure(Msg *m, int n)
          323 {
          324         int len;
          325         uchar *p;
          326         uchar *obp;
          327 
          328         if(m->bp == nil)
          329                 m->bpalloc = 1;
          330         if(!m->bpalloc){
          331                 p = emalloc(m->ep - m->bp);
          332                 memmove(p, m->bp, m->ep - m->bp);
          333                 obp = m->bp;
          334                 m->bp = p;
          335                 m->ep += m->bp - obp;
          336                 m->p += m->bp - obp;
          337                 m->bpalloc = 1;
          338         }
          339         len = m->ep - m->bp;
          340         if(m->p+n > m->ep){
          341                 obp = m->bp;
          342                 m->bp = erealloc(m->bp, len+n+1024);
          343                 m->p += m->bp - obp;
          344                 m->ep += m->bp - obp;
          345                 m->ep += n+1024;
          346         }
          347         p = m->p;
          348         m->p += n;
          349         return p;
          350 }
          351 
          352 void
          353 put4(Msg *m, uint n)
          354 {
          355         uchar *p;
          356 
          357         p = ensure(m, 4);
          358         p[0] = (n>>24)&0xFF;
          359         p[1] = (n>>16)&0xFF;
          360         p[2] = (n>>8)&0xFF;
          361         p[3] = n&0xFF;
          362 }
          363 
          364 void
          365 put2(Msg *m, uint n)
          366 {
          367         uchar *p;
          368 
          369         p = ensure(m, 2);
          370         p[0] = (n>>8)&0xFF;
          371         p[1] = n&0xFF;
          372 }
          373 
          374 void
          375 put1(Msg *m, uint n)
          376 {
          377         uchar *p;
          378 
          379         p = ensure(m, 1);
          380         p[0] = n&0xFF;
          381 }
          382 
          383 void
          384 putn(Msg *m, void *a, uint n)
          385 {
          386         uchar *p;
          387 
          388         p = ensure(m, n);
          389         memmove(p, a, n);
          390 }
          391 
          392 void
          393 putmp(Msg *m, mpint *b)
          394 {
          395         int bits, n;
          396         uchar *p;
          397 
          398         bits = mpsignif(b);
          399         put2(m, bits);
          400         n = (bits+7)/8;
          401         p = ensure(m, n);
          402         mptobe(b, p, n, nil);
          403 }
          404 
          405 void
          406 putmp2(Msg *m, mpint *b)
          407 {
          408         int bits, n;
          409         uchar *p;
          410 
          411         if(mpcmp(b, mpzero) == 0){
          412                 put4(m, 0);
          413                 return;
          414         }
          415         bits = mpsignif(b);
          416         n = (bits+7)/8;
          417         if(bits%8 == 0){
          418                 put4(m, n+1);
          419                 put1(m, 0);
          420         }else
          421                 put4(m, n);
          422         p = ensure(m, n);
          423         mptobe(b, p, n, nil);
          424 }
          425 
          426 void
          427 putstr(Msg *m, char *s)
          428 {
          429         int n;
          430 
          431         n = strlen(s);
          432         put4(m, n);
          433         putn(m, s, n);
          434 }
          435 
          436 void
          437 putm(Msg *m, Msg *mm)
          438 {
          439         uint n;
          440 
          441         n = mm->p - mm->bp;
          442         put4(m, n);
          443         putn(m, mm->bp, n);
          444 }
          445 
          446 void
          447 newreply(Msg *m, int type)
          448 {
          449         memset(m, 0, sizeof *m);
          450         put4(m, 0);
          451         put1(m, type);
          452 }
          453 
          454 void
          455 reply(Aconn *a, Msg *m)
          456 {
          457         uint n;
          458         uchar *p;
          459 
          460         n = (m->p - m->bp) - 4;
          461         p = m->bp;
          462         p[0] = (n>>24)&0xFF;
          463         p[1] = (n>>16)&0xFF;
          464         p[2] = (n>>8)&0xFF;
          465         p[3] = n&0xFF;
          466         if(chatty)
          467                 fprint(2, "respond %d t=%d: %.*H\n", n, p[4], n, m->bp+4);
          468         write(a->fd, p, n+4);
          469         mreset(m);
          470 }
          471 
/*
 * Key pairs an RSA modulus and public exponent with a comment.
 * NOTE(review): not referenced by any code visible here — confirm it
 * is still used before relying on it.
 */
typedef struct Key Key;
struct Key
{
	mpint *mod;	/* modulus */
	mpint *ek;	/* public (encryption) exponent */
	char *comment;
};
          479 
          480 static char*
          481 find(char **f, int nf, char *k)
          482 {
          483         int i, len;
          484 
          485         len = strlen(k);
          486         for(i=1; i<nf; i++)        /* i=1: f[0] is "key" */
          487                 if(strncmp(f[i], k, len) == 0 && f[i][len] == '=')
          488                         return f[i]+len+1;
          489         return nil;
          490 }
          491 
          492 static int
          493 putrsa1(Msg *m, char **f, int nf)
          494 {
          495         char *p;
          496         mpint *mod, *ek;
          497 
          498         p = find(f, nf, "n");
          499         if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
          500                 return -1;
          501         p = find(f, nf, "ek");
          502         if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
          503                 mpfree(mod);
          504                 return -1;
          505         }
          506         p = find(f, nf, "comment");
          507         if(p == nil)
          508                 p = "";
          509         put4(m, mpsignif(mod));
          510         putmp(m, ek);
          511         putmp(m, mod);
          512         putstr(m, p);
          513         mpfree(mod);
          514         mpfree(ek);
          515         return 0;
          516 }
          517 
          518 void
          519 printattr(char **f, int nf)
          520 {
          521         int i;
          522 
          523         print("#");
          524         for(i=0; i<nf; i++)
          525                 print(" %s", f[i]);
          526         print("\n");
          527 }
          528 
          529 void
          530 printrsa1(char **f, int nf)
          531 {
          532         char *p;
          533         mpint *mod, *ek;
          534 
          535         p = find(f, nf, "n");
          536         if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
          537                 return;
          538         p = find(f, nf, "ek");
          539         if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
          540                 mpfree(mod);
          541                 return;
          542         }
          543         p = find(f, nf, "comment");
          544         if(p == nil)
          545                 p = "";
          546 
          547         if(chatty)
          548                 printattr(f, nf);
          549         print("%d %.10B %.10B %s\n", mpsignif(mod), ek, mod, p);
          550         mpfree(ek);
          551         mpfree(mod);
          552 }
          553 
          554 static int
          555 putrsa(Msg *m, char **f, int nf)
          556 {
          557         char *p;
          558         mpint *mod, *ek;
          559 
          560         p = find(f, nf, "n");
          561         if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
          562                 return -1;
          563         p = find(f, nf, "ek");
          564         if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
          565                 mpfree(mod);
          566                 return -1;
          567         }
          568         putstr(m, "ssh-rsa");
          569         putmp2(m, ek);
          570         putmp2(m, mod);
          571         mpfree(ek);
          572         mpfree(mod);
          573         return 0;
          574 }
          575 
          576 RSApub*
          577 getrsapub(Msg *m)
          578 {
          579         RSApub *k;
          580 
          581         k = rsapuballoc();
          582         if(k == nil)
          583                 return nil;
          584         k->ek = getmp2(m);
          585         k->n = getmp2(m);
          586         if(k->ek == nil || k->n == nil){
          587                 rsapubfree(k);
          588                 return nil;
          589         }
          590         return k;
          591 }
          592 
          593 static int
          594 putdsa(Msg *m, char **f, int nf)
          595 {
          596         char *p;
          597         int ret;
          598         mpint *dp, *dq, *dalpha, *dkey;
          599 
          600         ret = -1;
          601         dp = dq = dalpha = dkey = nil;
          602         p = find(f, nf, "p");
          603         if(p == nil || (dp = strtomp(p, nil, 16, nil)) == nil)
          604                 goto out;
          605         p = find(f, nf, "q");
          606         if(p == nil || (dq = strtomp(p, nil, 16, nil)) == nil)
          607                 goto out;
          608         p = find(f, nf, "alpha");
          609         if(p == nil || (dalpha = strtomp(p, nil, 16, nil)) == nil)
          610                 goto out;
          611         p = find(f, nf, "key");
          612         if(p == nil || (dkey = strtomp(p, nil, 16, nil)) == nil)
          613                 goto out;
          614         putstr(m, "ssh-dss");
          615         putmp2(m, dp);
          616         putmp2(m, dq);
          617         putmp2(m, dalpha);
          618         putmp2(m, dkey);
          619         ret = 0;
          620 out:
          621         mpfree(dp);
          622         mpfree(dq);
          623         mpfree(dalpha);
          624         mpfree(dkey);
          625         return ret;
          626 }
          627 
          628 static int
          629 putkey2(Msg *m, int (*put)(Msg*,char**,int), char **f, int nf)
          630 {
          631         char *p;
          632         Msg mm;
          633 
          634         newmsg(&mm);
          635         if(put(&mm, f, nf) < 0)
          636                 return -1;
          637         putm(m, &mm);
          638         mreset(&mm);
          639         p = find(f, nf, "comment");
          640         if(p == nil)
          641                 p = "";
          642         putstr(m, p);
          643         return 0;
          644 }
          645 
          646 static int
          647 printkey(char *type, int (*put)(Msg*,char**,int), char **f, int nf)
          648 {
          649         Msg m;
          650         char *p;
          651 
          652         newmsg(&m);
          653         if(put(&m, f, nf) < 0)
          654                 return -1;
          655         p = find(f, nf, "comment");
          656         if(p == nil)
          657                 p = "";
          658         if(chatty)
          659                 printattr(f, nf);
          660         print("%s %.*[ %s\n", type, m.p-m.bp, m.bp, p);
          661         mreset(&m);
          662         return 0;
          663 }
          664 
          665 DSApub*
          666 getdsapub(Msg *m)
          667 {
          668         DSApub *k;
          669 
          670         k = dsapuballoc();
          671         if(k == nil)
          672                 return nil;
          673         k->p = getmp2(m);
          674         k->q = getmp2(m);
          675         k->alpha = getmp2(m);
          676         k->key = getmp2(m);
          677         if(!k->p || !k->q || !k->alpha || !k->key){
          678                 dsapubfree(k);
          679                 return nil;
          680         }
          681         return k;
          682 }
          683 
          684 static int
          685 listkeys(Msg *m, int version)
          686 {
          687         char buf[8192+1], *line[100], *f[20], *p, *s;
          688         int pnk;
          689         int i, n, nl, nf, nk;
          690         CFid *fid;
          691 
          692         nk = 0;
          693         pnk = m->p - m->bp;
          694         put4(m, 0);
          695         if((fid = nsopen(factotum, nil, "ctl", OREAD)) == nil){
          696                 fprint(2, "ssh-agent: open factotum: %r\n");
          697                 return -1;
          698         }
          699         for(;;){
          700                 if((n = fsread(fid, buf, sizeof buf-1)) <= 0)
          701                         break;
          702                 buf[n] = 0;
          703                 nl = getfields(buf, line, nelem(line), 1, "\n");
          704                 for(i=0; i<nl; i++){
          705                         nf = tokenize(line[i], f, nelem(f));
          706                         if(nf == 0 || strcmp(f[0], "key") != 0)
          707                                 continue;
          708                         p = find(f, nf, "proto");
          709                         if(p == nil)
          710                                 continue;
          711                         s = find(f, nf, "service");
          712                         if(s == nil)
          713                                 continue;
          714 
          715                         if(version == 1 && strcmp(p, "rsa") == 0 && strcmp(s, "ssh") == 0)
          716                                 if(putrsa1(m, f, nf) >= 0)
          717                                         nk++;
          718                         if(version == 2 && strcmp(p, "rsa") == 0 && strcmp(s, "ssh-rsa") == 0)
          719                                 if(putkey2(m, putrsa, f, nf) >= 0)
          720                                         nk++;
          721                         if(version == 2 && strcmp(p, "dsa") == 0 && strcmp(s, "ssh-dss") == 0)
          722                                 if(putkey2(m, putdsa, f, nf) >= 0)
          723                                         nk++;
          724                 }
          725         }
          726         if(chatty)
          727                 fprint(2, "sending %d keys\n", nk);
          728         fsclose(fid);
          729         m->bp[pnk+0] = (nk>>24)&0xFF;
          730         m->bp[pnk+1] = (nk>>16)&0xFF;
          731         m->bp[pnk+2] = (nk>>8)&0xFF;
          732         m->bp[pnk+3] = nk&0xFF;
          733         return nk;
          734 }
          735 
          736 void
          737 listkeystext(void)
          738 {
          739         char buf[8192+1], *line[100], *f[20], *p, *s;
          740         int i, n, nl, nf;
          741         CFid *fid;
          742 
          743         if((fid = nsopen(factotum, nil, "ctl", OREAD)) == nil){
          744                 fprint(2, "ssh-agent: open factotum: %r\n");
          745                 return;
          746         }
          747         for(;;){
          748                 if((n = fsread(fid, buf, sizeof buf-1)) <= 0)
          749                         break;
          750                 buf[n] = 0;
          751                 nl = getfields(buf, line, nelem(line), 1, "\n");
          752                 for(i=0; i<nl; i++){
          753                         nf = tokenize(line[i], f, nelem(f));
          754                         if(nf == 0 || strcmp(f[0], "key") != 0)
          755                                 continue;
          756                         p = find(f, nf, "proto");
          757                         if(p == nil)
          758                                 continue;
          759                         s = find(f, nf, "service");
          760                         if(s == nil)
          761                                 continue;
          762 
          763                         if(strcmp(p, "rsa") == 0 && strcmp(s, "ssh") == 0)
          764                                 printrsa1(f, nf);
          765                         if(strcmp(p, "rsa") == 0 && strcmp(s, "ssh-rsa") == 0)
          766                                 printkey("ssh-rsa", putrsa, f, nf);
          767                         if(strcmp(p, "dsa") == 0 && strcmp(s, "ssh-dss") == 0)
          768                                 printkey("ssh-dss", putdsa, f, nf);
          769                 }
          770         }
          771         fsclose(fid);
          772         threadexitsall(nil);
          773 }
          774 
          775 mpint*
          776 rsaunpad(mpint *b)
          777 {
          778         int i, n;
          779         uchar buf[2560];
          780 
          781         n = (mpsignif(b)+7)/8;
          782         if(n > sizeof buf){
          783                 werrstr("rsaunpad: too big");
          784                 return nil;
          785         }
          786         mptobe(b, buf, n, nil);
          787 
          788         /* the initial zero has been eaten by the betomp -> mptobe sequence */
          789         if(buf[0] != 2){
          790                 werrstr("rsaunpad: expected leading 2");
          791                 return nil;
          792         }
          793         for(i=1; i<n; i++)
          794                 if(buf[i]==0)
          795                         break;
          796         return betomp(buf+i, n-i, nil);
          797 }
          798 
          799 void
          800 mptoberjust(mpint *b, uchar *buf, int len)
          801 {
          802         int n;
          803 
          804         n = mptobe(b, buf, len, nil);
          805         assert(n >= 0);
          806         if(n < len){
          807                 len -= n;
          808                 memmove(buf+len, buf, n);
          809                 memset(buf, 0, len);
          810         }
          811 }
          812 
          813 static int
          814 dorsa(Aconn *a, mpint *mod, mpint *exp, mpint *chal, uchar chalbuf[32])
          815 {
          816         AuthRpc *rpc;
          817         char buf[4096], *p;
          818         mpint *decr, *unpad;
          819 
          820         USED(exp);
          821         if((rpc = auth_allocrpc()) == nil){
          822                 fprint(2, "ssh-agent: auth_allocrpc: %r\n");
          823                 return -1;
          824         }
          825         snprint(buf, sizeof buf, "proto=rsa service=ssh role=decrypt n=%lB ek=%lB", mod, exp);
          826         if(chatty)
          827                 fprint(2, "ssh-agent: start %s\n", buf);
          828         if(auth_rpc(rpc, "start", buf, strlen(buf)) != ARok){
          829                 fprint(2, "ssh-agent: auth 'start' failed: %r\n");
          830         Die:
          831                 auth_freerpc(rpc);
          832                 return -1;
          833         }
          834 
          835         p = mptoa(chal, 16, nil, 0);
          836         if(p == nil){
          837                 fprint(2, "ssh-agent: dorsa: mptoa: %r\n");
          838                 goto Die;
          839         }
          840         if(chatty)
          841                 fprint(2, "ssh-agent: challenge %B => %s\n", chal, p);
          842         if(auth_rpc(rpc, "writehex", p, strlen(p)) != ARok){
          843                 fprint(2, "ssh-agent: dorsa: auth 'write': %r\n");
          844                 free(p);
          845                 goto Die;
          846         }
          847         free(p);
          848         if(auth_rpc(rpc, "readhex", nil, 0) != ARok){
          849                 fprint(2, "ssh-agent: dorsa: auth 'read': %r\n");
          850                 goto Die;
          851         }
          852         decr = strtomp(rpc->arg, nil, 16, nil);
          853         if(chatty)
          854                 fprint(2, "ssh-agent: response %s => %B\n", rpc->arg, decr);
          855         if(decr == nil){
          856                 fprint(2, "ssh-agent: dorsa: strtomp: %r\n");
          857                 goto Die;
          858         }
          859         unpad = rsaunpad(decr);
          860         if(chatty)
          861                 fprint(2, "ssh-agent: unpad %B => %B\n", decr, unpad);
          862         if(unpad == nil){
          863                 fprint(2, "ssh-agent: dorsa: rsaunpad: %r\n");
          864                 mpfree(decr);
          865                 goto Die;
          866         }
          867         mpfree(decr);
          868         mptoberjust(unpad, chalbuf, 32);
          869         mpfree(unpad);
          870         auth_freerpc(rpc);
          871         return 0;
          872 }
          873 
          874 int
          875 keysign(Msg *mkey, Msg *mdata, Msg *msig)
          876 {
          877         char *s;
          878         AuthRpc *rpc;
          879         RSApub *rsa;
          880         DSApub *dsa;
          881         char buf[4096];
          882         uchar digest[SHA1dlen];
          883 
          884         s = getstr(mkey);
          885         if(strcmp(s, "ssh-rsa") == 0){
          886                 rsa = getrsapub(mkey);
          887                 if(rsa == nil)
          888                         return -1;
          889                 snprint(buf, sizeof buf, "proto=rsa service=ssh-rsa role=sign n=%lB ek=%lB",
          890                         rsa->n, rsa->ek);
          891                 rsapubfree(rsa);
          892         }else if(strcmp(s, "ssh-dss") == 0){
          893                 dsa = getdsapub(mkey);
          894                 if(dsa == nil)
          895                         return -1;
          896                 snprint(buf, sizeof buf, "proto=dsa service=ssh-dss role=sign p=%lB q=%lB alpha=%lB key=%lB",
          897                         dsa->p, dsa->q, dsa->alpha, dsa->key);
          898                 dsapubfree(dsa);
          899         }else{
          900                 fprint(2, "ssh-agent: cannot sign key type %s\n", s);
          901                 werrstr("unknown key type %s", s);
          902                 return -1;
          903         }
          904 
          905         if((rpc = auth_allocrpc()) == nil){
          906                 fprint(2, "ssh-agent: auth_allocrpc: %r\n");
          907                 return -1;
          908         }
          909         if(chatty)
          910                 fprint(2, "ssh-agent: start %s\n", buf);
          911         if(auth_rpc(rpc, "start", buf, strlen(buf)) != ARok){
          912                 fprint(2, "ssh-agent: auth 'start' failed: %r\n");
          913         Die:
          914                 auth_freerpc(rpc);
          915                 return -1;
          916         }
          917         sha1(mdata->bp, mdata->ep-mdata->bp, digest, nil);
          918         if(auth_rpc(rpc, "write", digest, SHA1dlen) != ARok){
          919                 fprint(2, "ssh-agent: auth 'write in sign failed: %r\n");
          920                 goto Die;
          921         }
          922         if(auth_rpc(rpc, "read", nil, 0) != ARok){
          923                 fprint(2, "ssh-agent: auth 'read' failed: %r\n");
          924                 goto Die;
          925         }
          926         newmsg(msig);
          927         putstr(msig, s);
          928         put4(msig, rpc->narg);
          929         putn(msig, rpc->arg, rpc->narg);
          930         auth_freerpc(rpc);
          931         return 0;
          932 }
          933 
          934 int
          935 runmsg(Aconn *a)
          936 {
          937         char *p;
          938         int n, nk, type, rt, vers;
          939         mpint *ek, *mod, *chal;
          940         uchar sessid[16], chalbuf[32], digest[MD5dlen];
          941         uint len, flags;
          942         DigestState *s;
          943         Msg m, mkey, mdata, msig;
          944 
          945         if(a->ndata < 4)
          946                 return 0;
          947         len = (a->data[0]<<24)|(a->data[1]<<16)|(a->data[2]<<8)|a->data[3];
          948         if(a->ndata < 4+len)
          949                 return 0;
          950         m.p = a->data+4;
          951         m.ep = m.p+len;
          952         type = get1(&m);
          953         if(chatty)
          954                 fprint(2, "msg %d: %.*H\n", type, len, m.p);
          955         switch(type){
          956         default:
          957         Failure:
          958                 newreply(&m, SSH_AGENT_FAILURE);
          959                 reply(a, &m);
          960                 break;
          961 
          962         case SSH_AGENTC_REQUEST_RSA_IDENTITIES:
          963                 vers = 1;
          964                 newreply(&m, SSH_AGENT_RSA_IDENTITIES_ANSWER);
          965                 goto Identities;
          966         case SSH2_AGENTC_REQUEST_IDENTITIES:
          967                 vers = 2;
          968                 newreply(&m, SSH2_AGENT_IDENTITIES_ANSWER);
          969         Identities:
          970                 nk = listkeys(&m, vers);
          971                 if(nk < 0){
          972                         mreset(&m);
          973                         goto Failure;
          974                 }
          975                 if(chatty)
          976                         fprint(2, "request identities\n", nk);
          977                 reply(a, &m);
          978                 break;
          979 
          980         case SSH_AGENTC_RSA_CHALLENGE:
          981                 n = get4(&m);
          982                 USED(n);
          983                 ek = getmp(&m);
          984                 mod = getmp(&m);
          985                 chal = getmp(&m);
          986                 if((p = (char*)getn(&m, 16)) == nil){
          987                 Failchal:
          988                         mpfree(ek);
          989                         mpfree(mod);
          990                         mpfree(chal);
          991                         goto Failure;
          992                 }
          993                 memmove(sessid, p, 16);
          994                 rt = get4(&m);
          995                 if(rt != 1 || dorsa(a, mod, ek, chal, chalbuf) < 0)
          996                         goto Failchal;
          997                 s = md5(chalbuf, 32, nil, nil);
          998                 if(s == nil)
          999                         goto Failchal;
         1000                 md5(sessid, 16, digest, s);
         1001                 print("md5 %.*H %.*H => %.*H\n", 32, chalbuf, 16, sessid, MD5dlen, digest);
         1002 
         1003                 newreply(&m, SSH_AGENT_RSA_RESPONSE);
         1004                 putn(&m, digest, 16);
         1005                 reply(a, &m);
         1006 
         1007                 mpfree(ek);
         1008                 mpfree(mod);
         1009                 mpfree(chal);
         1010                 break;
         1011 
         1012         case SSH2_AGENTC_SIGN_REQUEST:
         1013                 if(getm(&m, &mkey) == nil
         1014                 || getm(&m, &mdata) == nil)
         1015                         goto Failure;
         1016                 flags = get4(&m);
         1017                 if(flags & SSH_AGENT_OLD_SIGNATURE)
         1018                         goto Failure;
         1019                 if(keysign(&mkey, &mdata, &msig) < 0)
         1020                         goto Failure;
         1021                 if(chatty)
         1022                         fprint(2, "signature: %.*H\n",
         1023                                 msig.p-msig.bp, msig.bp);
         1024                 newreply(&m, SSH2_AGENT_SIGN_RESPONSE);
         1025                 putm(&m, &msig);
         1026                 mreset(&msig);
         1027                 reply(a, &m);
         1028                 break;
         1029 
         1030         case SSH_AGENTC_ADD_RSA_IDENTITY:
         1031                 /*
         1032                         msg: n[4] mod[mp] pubexp[exp] privexp[mp]
         1033                                 p^-1 mod q[mp] p[mp] q[mp] comment[str]
         1034                  */
         1035                 goto Failure;
         1036 
         1037         case SSH_AGENTC_REMOVE_RSA_IDENTITY:
         1038                 /*
         1039                         msg: n[4] mod[mp] pubexp[mp]
         1040                  */
         1041                 goto Failure;
         1042 
         1043         }
         1044 
         1045         a->ndata -= 4+len;
         1046         memmove(a->data, a->data+4+len, a->ndata);
         1047         return 1;
         1048 }
         1049 
         1050 void*
         1051 emalloc(int n)
         1052 {
         1053         void *v;
         1054 
         1055         v = mallocz(n, 1);
         1056         if(v == nil){
         1057                 abort();
         1058                 sysfatal("out of memory allocating %d", n);
         1059         }
         1060         return v;
         1061 }
         1062 
         1063 void*
         1064 erealloc(void *v, int n)
         1065 {
         1066         v = realloc(v, n);
         1067         if(v == nil){
         1068                 abort();
         1069                 sysfatal("out of memory reallocating %d", n);
         1070         }
         1071         return v;
         1072 }