URI:
       ttlshand.c - plan9port - [fork] Plan 9 from user space
  HTML git clone git://src.adamsgaard.dk/plan9port
   DIR Log
   DIR Files
   DIR Refs
   DIR README
   DIR LICENSE
       ---
       ttlshand.c (54440B)
       ---
            1 #include <u.h>
            2 #include <libc.h>
            3 #include <bio.h>
            4 #include <auth.h>
            5 #include <mp.h>
            6 #include <libsec.h>
            7 
            8 /* The main groups of functions are: */
            9 /*                client/server - main handshake protocol definition */
/*                message functions - formatting handshake messages */
           11 /*                cipher choices - catalog of digest and encrypt algorithms */
           12 /*                security functions - PKCS#1, sslHMAC, session keygen */
           13 /*                general utility functions - malloc, serialization */
           14 /* The handshake protocol builds on the TLS/SSL3 record layer protocol, */
           15 /* which is implemented in kernel device #a.  See also /lib/rfc/rfc2246. */
           16 
           17 enum {
           18         TLSFinishedLen = 12,
           19         SSL3FinishedLen = MD5dlen+SHA1dlen,
           20         MaxKeyData = 104,        /* amount of secret we may need */
           21         MaxChunk = 1<<14,
           22         RandomSize = 32,
           23         SidSize = 32,
           24         MasterSecretSize = 48,
           25         AQueue = 0,
           26         AFlush = 1
           27 };
           28 
/* Security state for one handshake; the full definition is struct TlsSec below. */
typedef struct TlsSec TlsSec;

/*
 * Counted byte string.  data is declared with one element but is used
 * as data[len]; presumably newbytes/makebytes allocate the trailing
 * room (NOTE(review): confirm there before changing this layout).
 */
typedef struct Bytes{
	int len;
	uchar data[1];  /* [len] */
} Bytes;

/* Counted int array; same trailing-array convention as Bytes. */
typedef struct Ints{
	int len;
	int data[1];  /* [len] */
} Ints;

/* One entry of the cipherAlgs table: a cipher suite we can use. */
typedef struct Algs{
	char *enc;	/* record-layer encryption algorithm name, e.g. "rc4_128" */
	char *digest;	/* record-layer digest name, e.g. "md5" */
	int nsecret;	/* bytes of key material needed to init keys */
	int tlsid;	/* cipher suite number (one of the TLS_* constants) */
	int ok;		/* usability flag; presumably set by initCiphers — TODO confirm */
} Algs;

/*
 * Finished message contents.  verify is sized for the larger SSL3 form;
 * n is the number of bytes of verify in use.
 */
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;
           53 
/*
 * Per-conversation handshake state, used on both the client and the
 * server side.  Allocated by tlsServer2/tlsClient2 and released by the
 * callers with tlsConnectionFree.
 */
typedef struct TlsConnection{
	TlsSec *sec;	/* security management goo */
	int hand, ctl;	/* record layer file descriptors */
	int erred;		/* set when tlsError called */
	int (*trace)(char*fmt, ...); /* for debugging */
	int version;	/* protocol we are speaking */
	int verset;		/* version has been set */
	int ver2hi;		/* server got a version 2 hello */
	int isClient;	/* is this the client or server? */
	Bytes *sid;		/* SessionID */
	Bytes *cert;	/* one certificate, no chain: tlsClient2 keeps certs[0] of the server's Certificate message */

	Lock statelk;
	int state;		/* must be set using setstate */

	/* input buffer for handshake messages */
	uchar buf[MaxChunk+2048];
	uchar *rp, *ep;	/* presumably read and end positions within buf — TODO confirm in msgRecv */

	uchar crandom[RandomSize];	/* client random */
	uchar srandom[RandomSize];	/* server random */
	int clientVersion;	/* version in ClientHello */
	char *digest;	/* name of digest algorithm to use */
	char *enc;		/* name of encryption algorithm to use */
	int nsecret;	/* amount of secret data to init keys */

	/* for finished messages */
	MD5state	hsmd5;	/* handshake hash, passed to tlsSecFinished */
	SHAstate	hssha1;	/* handshake hash, passed to tlsSecFinished */
	Finished	finished;	/* filled in by tlsSecFinished; sent or compared against the peer's */
} TlsConnection;
           85 
/*
 * A decoded handshake message.  tag holds the handshake type (one of
 * the H* constants) and selects which member of u is meaningful;
 * msgClear releases the Bytes/Ints the members point at.
 */
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;	/* offered cipher suite numbers */
			Bytes*	compressors;	/* offered compression methods */
		} clientHello;
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			int cipher;		/* chosen cipher suite number */
			int compressor;	/* chosen compression method */
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	/* ncert entries; certs[0] is the sender's own certificate */
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *key;	/* encrypted pre-master secret (passed to tlsSecSecrets) */
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;
          118 
/*
 * Security state for a handshake: master secret, both randoms, the
 * negotiated version, and the PRF and Finished routines chosen for it.
 */
struct TlsSec{
	char *server;	/* name of remote; nil for server */
	int ok;	/* <0 killed; ==0 in progress; >0 reusable */
	RSApub *rsapub;	/* public key; tlsServer2 sets it from its own certificate */
	AuthRpc *rpc;	/* factotum for rsa private key */
	uchar sec[MasterSecretSize];	/* master secret */
	uchar crandom[RandomSize];	/* client random */
	uchar srandom[RandomSize];	/* server random */
	int clientVers;		/* version in ClientHello */
	int vers;			/* final version */
	/* byte generation and handshake checksum */
	/* prf has sslPRF's shape: (buf, nbuf, key, nkey, label, seed0, nseed0, seed1, nseed1) */
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	/* setFinished has the shape of tlsSetFinished/sslSetFinished */
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;	/* presumably the Finished length for vers — TODO confirm */
};
          134 
          135 
          136 enum {
          137         TLSVersion = 0x0301,
          138         SSL3Version = 0x0300,
          139         ProtocolVersion = 0x0301,        /* maximum version we speak */
          140         MinProtoVersion = 0x0300,        /* limits on version we accept */
          141         MaxProtoVersion        = 0x03ff
          142 };
          143 
          144 /* handshake type */
          145 enum {
          146         HHelloRequest,
          147         HClientHello,
          148         HServerHello,
          149         HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
          150         HCertificate = 11,
          151         HServerKeyExchange,
          152         HCertificateRequest,
          153         HServerHelloDone,
          154         HCertificateVerify,
          155         HClientKeyExchange,
          156         HFinished = 20,
          157         HMax
          158 };
          159 
          160 /* alerts */
          161 enum {
          162         ECloseNotify = 0,
          163         EUnexpectedMessage = 10,
          164         EBadRecordMac = 20,
          165         EDecryptionFailed = 21,
          166         ERecordOverflow = 22,
          167         EDecompressionFailure = 30,
          168         EHandshakeFailure = 40,
          169         ENoCertificate = 41,
          170         EBadCertificate = 42,
          171         EUnsupportedCertificate = 43,
          172         ECertificateRevoked = 44,
          173         ECertificateExpired = 45,
          174         ECertificateUnknown = 46,
          175         EIllegalParameter = 47,
          176         EUnknownCa = 48,
          177         EAccessDenied = 49,
          178         EDecodeError = 50,
          179         EDecryptError = 51,
          180         EExportRestriction = 60,
          181         EProtocolVersion = 70,
          182         EInsufficientSecurity = 71,
          183         EInternalError = 80,
          184         EUserCanceled = 90,
          185         ENoRenegotiation = 100,
          186         EMax = 256
          187 };
          188 
          189 /* cipher suites */
          190 enum {
          191         TLS_NULL_WITH_NULL_NULL                         = 0x0000,
          192         TLS_RSA_WITH_NULL_MD5                         = 0x0001,
          193         TLS_RSA_WITH_NULL_SHA                         = 0x0002,
          194         TLS_RSA_EXPORT_WITH_RC4_40_MD5                 = 0x0003,
          195         TLS_RSA_WITH_RC4_128_MD5                 = 0x0004,
          196         TLS_RSA_WITH_RC4_128_SHA                 = 0x0005,
          197         TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5        = 0X0006,
          198         TLS_RSA_WITH_IDEA_CBC_SHA                 = 0X0007,
          199         TLS_RSA_EXPORT_WITH_DES40_CBC_SHA        = 0X0008,
          200         TLS_RSA_WITH_DES_CBC_SHA                = 0X0009,
          201         TLS_RSA_WITH_3DES_EDE_CBC_SHA                = 0X000A,
          202         TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA        = 0X000B,
          203         TLS_DH_DSS_WITH_DES_CBC_SHA                = 0X000C,
          204         TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA        = 0X000D,
          205         TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA        = 0X000E,
          206         TLS_DH_RSA_WITH_DES_CBC_SHA                = 0X000F,
          207         TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA        = 0X0010,
          208         TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA        = 0X0011,
          209         TLS_DHE_DSS_WITH_DES_CBC_SHA                = 0X0012,
          210         TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA        = 0X0013,        /* ZZZ must be implemented for tls1.0 compliance */
          211         TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA        = 0X0014,
          212         TLS_DHE_RSA_WITH_DES_CBC_SHA                = 0X0015,
          213         TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA        = 0X0016,
          214         TLS_DH_anon_EXPORT_WITH_RC4_40_MD5        = 0x0017,
          215         TLS_DH_anon_WITH_RC4_128_MD5                 = 0x0018,
          216         TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA        = 0X0019,
          217         TLS_DH_anon_WITH_DES_CBC_SHA                = 0X001A,
          218         TLS_DH_anon_WITH_3DES_EDE_CBC_SHA        = 0X001B,
          219 
          220         TLS_RSA_WITH_AES_128_CBC_SHA                = 0X002f,        /* aes, aka rijndael with 128 bit blocks */
          221         TLS_DH_DSS_WITH_AES_128_CBC_SHA                = 0X0030,
          222         TLS_DH_RSA_WITH_AES_128_CBC_SHA                = 0X0031,
          223         TLS_DHE_DSS_WITH_AES_128_CBC_SHA        = 0X0032,
          224         TLS_DHE_RSA_WITH_AES_128_CBC_SHA        = 0X0033,
          225         TLS_DH_anon_WITH_AES_128_CBC_SHA        = 0X0034,
          226         TLS_RSA_WITH_AES_256_CBC_SHA                = 0X0035,
          227         TLS_DH_DSS_WITH_AES_256_CBC_SHA                = 0X0036,
          228         TLS_DH_RSA_WITH_AES_256_CBC_SHA                = 0X0037,
          229         TLS_DHE_DSS_WITH_AES_256_CBC_SHA        = 0X0038,
          230         TLS_DHE_RSA_WITH_AES_256_CBC_SHA        = 0X0039,
          231         TLS_DH_anon_WITH_AES_256_CBC_SHA        = 0X003A,
          232         CipherMax
          233 };
          234 
          235 /* compression methods */
          236 enum {
          237         CompressionNull = 0,
          238         CompressionMax
          239 };
          240 
          241 static Algs cipherAlgs[] = {
          242         {"rc4_128", "md5",        2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
          243         {"rc4_128", "sha1",        2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
          244         {"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
          245 };
          246 
          247 static uchar compressors[] = {
          248         CompressionNull,
          249 };
          250 
/* handshake drivers; tlsServer and tlsClient wrap these */
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));

/* handshake message I/O and connection bookkeeping; msgSend's act is AQueue or AFlush */
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
/* #pragma	varargck argpos	tlsError 3*/
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

/* cipher and compression negotiation; okCipher returns a suite number, or <0 (-2: offered suites too weak) */
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(void);

/* security functions: session setup, key material, master secret, Finished computation */
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void	tlsSecOk(TlsSec *sec);
/* static void	tlsSecKill(TlsSec *sec); */
static void	tlsSecClose(TlsSec *sec);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

/* RSA private-key operations through factotum; tlsServer2 opens the RPC with its certificate */
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc*rpc);

/* general utilities: allocation, wire-format integer packing, counted arrays */
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
/* static u32int get32(uchar *p); */
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static void freebytes(Bytes* b);
static Ints* newints(int len);
/* static Ints* makeints(int* buf, int len); */
static void freeints(Ints* b);
          308 
          309 /*================= client/server ======================== */
          310 
          311 /*        push TLS onto fd, returning new (application) file descriptor */
          312 /*                or -1 if error. */
          313 int
          314 tlsServer(int fd, TLSconn *conn)
          315 {
          316         char buf[8];
          317         char dname[64];
          318         int n, data, ctl, hand;
          319         TlsConnection *tls;
          320 
          321         if(conn == nil)
          322                 return -1;
          323         ctl = open("#a/tls/clone", ORDWR);
          324         if(ctl < 0)
          325                 return -1;
          326         n = read(ctl, buf, sizeof(buf)-1);
          327         if(n < 0){
          328                 close(ctl);
          329                 return -1;
          330         }
          331         buf[n] = 0;
          332         sprint(conn->dir, "#a/tls/%s", buf);
          333         sprint(dname, "#a/tls/%s/hand", buf);
          334         hand = open(dname, ORDWR);
          335         if(hand < 0){
          336                 close(ctl);
          337                 return -1;
          338         }
          339         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
          340         tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
          341         sprint(dname, "#a/tls/%s/data", buf);
          342         data = open(dname, ORDWR);
          343         close(fd);
          344         close(hand);
          345         close(ctl);
          346         if(data < 0){
          347                 return -1;
          348         }
          349         if(tls == nil){
          350                 close(data);
          351                 return -1;
          352         }
          353         if(conn->cert)
          354                 free(conn->cert);
          355         conn->cert = 0;  /* client certificates are not yet implemented */
          356         conn->certlen = 0;
          357         conn->sessionIDlen = tls->sid->len;
          358         conn->sessionID = emalloc(conn->sessionIDlen);
          359         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
          360         tlsConnectionFree(tls);
          361         return data;
          362 }
          363 
          364 /*        push TLS onto fd, returning new (application) file descriptor */
          365 /*                or -1 if error. */
          366 int
          367 tlsClient(int fd, TLSconn *conn)
          368 {
          369         char buf[8];
          370         char dname[64];
          371         int n, data, ctl, hand;
          372         TlsConnection *tls;
          373 
          374         if(!conn)
          375                 return -1;
          376         ctl = open("#a/tls/clone", ORDWR);
          377         if(ctl < 0)
          378                 return -1;
          379         n = read(ctl, buf, sizeof(buf)-1);
          380         if(n < 0){
          381                 close(ctl);
          382                 return -1;
          383         }
          384         buf[n] = 0;
          385         sprint(conn->dir, "#a/tls/%s", buf);
          386         sprint(dname, "#a/tls/%s/hand", buf);
          387         hand = open(dname, ORDWR);
          388         if(hand < 0){
          389                 close(ctl);
          390                 return -1;
          391         }
          392         sprint(dname, "#a/tls/%s/data", buf);
          393         data = open(dname, ORDWR);
          394         if(data < 0)
          395                 return -1;
          396         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
          397         tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
          398         close(fd);
          399         close(hand);
          400         close(ctl);
          401         if(tls == nil){
          402                 close(data);
          403                 return -1;
          404         }
          405         conn->certlen = tls->cert->len;
          406         conn->cert = emalloc(conn->certlen);
          407         memcpy(conn->cert, tls->cert->data, conn->certlen);
          408         conn->sessionIDlen = tls->sid->len;
          409         conn->sessionID = emalloc(conn->sessionIDlen);
          410         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
          411         tlsConnectionFree(tls);
          412         return data;
          413 }
          414 
          415 static int
          416 countchain(PEMChain *p)
          417 {
          418         int i = 0;
          419 
          420         while (p) {
          421                 i++;
          422                 p = p->next;
          423         }
          424         return i;
          425 }
          426 
          427 static TlsConnection *
          428 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
          429 {
          430         TlsConnection *c;
          431         Msg m;
          432         Bytes *csid;
          433         uchar sid[SidSize], kd[MaxKeyData];
          434         char *secrets;
          435         int cipher, compressor, nsid, rv, numcerts, i;
          436 
          437         if(trace)
          438                 trace("tlsServer2\n");
          439         if(!initCiphers())
          440                 return nil;
          441         c = emalloc(sizeof(TlsConnection));
          442         c->ctl = ctl;
          443         c->hand = hand;
          444         c->trace = trace;
          445         c->version = ProtocolVersion;
          446 
          447         memset(&m, 0, sizeof(m));
          448         if(!msgRecv(c, &m)){
          449                 if(trace)
          450                         trace("initial msgRecv failed\n");
          451                 goto Err;
          452         }
          453         if(m.tag != HClientHello) {
          454                 tlsError(c, EUnexpectedMessage, "expected a client hello");
          455                 goto Err;
          456         }
          457         c->clientVersion = m.u.clientHello.version;
          458         if(trace)
          459                 trace("ClientHello version %x\n", c->clientVersion);
          460         if(setVersion(c, m.u.clientHello.version) < 0) {
          461                 tlsError(c, EIllegalParameter, "incompatible version");
          462                 goto Err;
          463         }
          464 
          465         memmove(c->crandom, m.u.clientHello.random, RandomSize);
          466         cipher = okCipher(m.u.clientHello.ciphers);
          467         if(cipher < 0) {
          468                 /* reply with EInsufficientSecurity if we know that's the case */
          469                 if(cipher == -2)
          470                         tlsError(c, EInsufficientSecurity, "cipher suites too weak");
          471                 else
          472                         tlsError(c, EHandshakeFailure, "no matching cipher suite");
          473                 goto Err;
          474         }
          475         if(!setAlgs(c, cipher)){
          476                 tlsError(c, EHandshakeFailure, "no matching cipher suite");
          477                 goto Err;
          478         }
          479         compressor = okCompression(m.u.clientHello.compressors);
          480         if(compressor < 0) {
          481                 tlsError(c, EHandshakeFailure, "no matching compressor");
          482                 goto Err;
          483         }
          484 
          485         csid = m.u.clientHello.sid;
          486         if(trace)
          487                 trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
          488         c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
          489         if(c->sec == nil){
          490                 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
          491                 goto Err;
          492         }
          493         c->sec->rpc = factotum_rsa_open(cert, ncert);
          494         if(c->sec->rpc == nil){
          495                 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
          496                 goto Err;
          497         }
          498         c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
          499         msgClear(&m);
          500 
          501         m.tag = HServerHello;
          502         m.u.serverHello.version = c->version;
          503         memmove(m.u.serverHello.random, c->srandom, RandomSize);
          504         m.u.serverHello.cipher = cipher;
          505         m.u.serverHello.compressor = compressor;
          506         c->sid = makebytes(sid, nsid);
          507         m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
          508         if(!msgSend(c, &m, AQueue))
          509                 goto Err;
          510         msgClear(&m);
          511 
          512         m.tag = HCertificate;
          513         numcerts = countchain(chp);
          514         m.u.certificate.ncert = 1 + numcerts;
          515         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
          516         m.u.certificate.certs[0] = makebytes(cert, ncert);
          517         for (i = 0; i < numcerts && chp; i++, chp = chp->next)
          518                 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
          519         if(!msgSend(c, &m, AQueue))
          520                 goto Err;
          521         msgClear(&m);
          522 
          523         m.tag = HServerHelloDone;
          524         if(!msgSend(c, &m, AFlush))
          525                 goto Err;
          526         msgClear(&m);
          527 
          528         if(!msgRecv(c, &m))
          529                 goto Err;
          530         if(m.tag != HClientKeyExchange) {
          531                 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
          532                 goto Err;
          533         }
          534         if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
          535                 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
          536                 goto Err;
          537         }
          538         if(trace)
          539                 trace("tls secrets\n");
          540         secrets = (char*)emalloc(2*c->nsecret);
          541         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
          542         rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
          543         memset(secrets, 0, 2*c->nsecret);
          544         free(secrets);
          545         memset(kd, 0, c->nsecret);
          546         if(rv < 0){
          547                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
          548                 goto Err;
          549         }
          550         msgClear(&m);
          551 
          552         /* no CertificateVerify; skip to Finished */
          553         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
          554                 tlsError(c, EInternalError, "can't set finished: %r");
          555                 goto Err;
          556         }
          557         if(!msgRecv(c, &m))
          558                 goto Err;
          559         if(m.tag != HFinished) {
          560                 tlsError(c, EUnexpectedMessage, "expected a finished");
          561                 goto Err;
          562         }
          563         if(!finishedMatch(c, &m.u.finished)) {
          564                 tlsError(c, EHandshakeFailure, "finished verification failed");
          565                 goto Err;
          566         }
          567         msgClear(&m);
          568 
          569         /* change cipher spec */
          570         if(fprint(c->ctl, "changecipher") < 0){
          571                 tlsError(c, EInternalError, "can't enable cipher: %r");
          572                 goto Err;
          573         }
          574 
          575         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
          576                 tlsError(c, EInternalError, "can't set finished: %r");
          577                 goto Err;
          578         }
          579         m.tag = HFinished;
          580         m.u.finished = c->finished;
          581         if(!msgSend(c, &m, AFlush))
          582                 goto Err;
          583         msgClear(&m);
          584         if(trace)
          585                 trace("tls finished\n");
          586 
          587         if(fprint(c->ctl, "opened") < 0)
          588                 goto Err;
          589         tlsSecOk(c->sec);
          590         return c;
          591 
          592 Err:
          593         msgClear(&m);
          594         tlsConnectionFree(c);
          595         return 0;
          596 }
          597 
          598 static TlsConnection *
          599 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
          600 {
          601         TlsConnection *c;
          602         Msg m;
          603         uchar kd[MaxKeyData], *epm;
          604         char *secrets;
          605         int creq, nepm, rv;
          606 
          607         if(!initCiphers())
          608                 return nil;
          609         epm = nil;
          610         c = emalloc(sizeof(TlsConnection));
          611         c->version = ProtocolVersion;
          612         c->ctl = ctl;
          613         c->hand = hand;
          614         c->trace = trace;
          615         c->isClient = 1;
          616         c->clientVersion = c->version;
          617 
          618         c->sec = tlsSecInitc(c->clientVersion, c->crandom);
          619         if(c->sec == nil)
          620                 goto Err;
          621 
          622         /* client hello */
          623         memset(&m, 0, sizeof(m));
          624         m.tag = HClientHello;
          625         m.u.clientHello.version = c->clientVersion;
          626         memmove(m.u.clientHello.random, c->crandom, RandomSize);
          627         m.u.clientHello.sid = makebytes(csid, ncsid);
          628         m.u.clientHello.ciphers = makeciphers();
          629         m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
          630         if(!msgSend(c, &m, AFlush))
          631                 goto Err;
          632         msgClear(&m);
          633 
          634         /* server hello */
          635         if(!msgRecv(c, &m))
          636                 goto Err;
          637         if(m.tag != HServerHello) {
          638                 tlsError(c, EUnexpectedMessage, "expected a server hello");
          639                 goto Err;
          640         }
          641         if(setVersion(c, m.u.serverHello.version) < 0) {
          642                 tlsError(c, EIllegalParameter, "incompatible version %r");
          643                 goto Err;
          644         }
          645         memmove(c->srandom, m.u.serverHello.random, RandomSize);
          646         c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
          647         if(c->sid->len != 0 && c->sid->len != SidSize) {
          648                 tlsError(c, EIllegalParameter, "invalid server session identifier");
          649                 goto Err;
          650         }
          651         if(!setAlgs(c, m.u.serverHello.cipher)) {
          652                 tlsError(c, EIllegalParameter, "invalid cipher suite");
          653                 goto Err;
          654         }
          655         if(m.u.serverHello.compressor != CompressionNull) {
          656                 tlsError(c, EIllegalParameter, "invalid compression");
          657                 goto Err;
          658         }
          659         msgClear(&m);
          660 
          661         /* certificate */
          662         if(!msgRecv(c, &m) || m.tag != HCertificate) {
          663                 tlsError(c, EUnexpectedMessage, "expected a certificate");
          664                 goto Err;
          665         }
          666         if(m.u.certificate.ncert < 1) {
          667                 tlsError(c, EIllegalParameter, "runt certificate");
          668                 goto Err;
          669         }
          670         c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
          671         msgClear(&m);
          672 
          673         /* server key exchange (optional) */
          674         if(!msgRecv(c, &m))
          675                 goto Err;
          676         if(m.tag == HServerKeyExchange) {
          677                 tlsError(c, EUnexpectedMessage, "got an server key exchange");
          678                 goto Err;
          679                 /* If implementing this later, watch out for rollback attack */
          680                 /* described in Wagner Schneier 1996, section 4.4. */
          681         }
          682 
          683         /* certificate request (optional) */
          684         creq = 0;
          685         if(m.tag == HCertificateRequest) {
          686                 creq = 1;
          687                 msgClear(&m);
          688                 if(!msgRecv(c, &m))
          689                         goto Err;
          690         }
          691 
          692         if(m.tag != HServerHelloDone) {
          693                 tlsError(c, EUnexpectedMessage, "expected a server hello done");
          694                 goto Err;
          695         }
          696         msgClear(&m);
          697 
          698         if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
          699                         c->cert->data, c->cert->len, c->version, &epm, &nepm,
          700                         kd, c->nsecret) < 0){
          701                 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
          702                 goto Err;
          703         }
          704         secrets = (char*)emalloc(2*c->nsecret);
          705         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
          706         rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
          707         memset(secrets, 0, 2*c->nsecret);
          708         free(secrets);
          709         memset(kd, 0, c->nsecret);
          710         if(rv < 0){
          711                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
          712                 goto Err;
          713         }
          714 
          715         if(creq) {
          716                 /* send a zero length certificate */
          717                 m.tag = HCertificate;
          718                 if(!msgSend(c, &m, AFlush))
          719                         goto Err;
          720                 msgClear(&m);
          721         }
          722 
          723         /* client key exchange */
          724         m.tag = HClientKeyExchange;
          725         m.u.clientKeyExchange.key = makebytes(epm, nepm);
          726         free(epm);
          727         epm = nil;
          728         if(m.u.clientKeyExchange.key == nil) {
          729                 tlsError(c, EHandshakeFailure, "can't set secret: %r");
          730                 goto Err;
          731         }
          732         if(!msgSend(c, &m, AFlush))
          733                 goto Err;
          734         msgClear(&m);
          735 
          736         /* change cipher spec */
          737         if(fprint(c->ctl, "changecipher") < 0){
          738                 tlsError(c, EInternalError, "can't enable cipher: %r");
          739                 goto Err;
          740         }
          741 
          742         /* Cipherchange must occur immediately before Finished to avoid */
          743         /* potential hole;  see section 4.3 of Wagner Schneier 1996. */
          744         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
          745                 tlsError(c, EInternalError, "can't set finished 1: %r");
          746                 goto Err;
          747         }
          748         m.tag = HFinished;
          749         m.u.finished = c->finished;
          750 
          751         if(!msgSend(c, &m, AFlush)) {
          752                 fprint(2, "tlsClient nepm=%d\n", nepm);
          753                 tlsError(c, EInternalError, "can't flush after client Finished: %r");
          754                 goto Err;
          755         }
          756         msgClear(&m);
          757 
          758         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
          759                 fprint(2, "tlsClient nepm=%d\n", nepm);
          760                 tlsError(c, EInternalError, "can't set finished 0: %r");
          761                 goto Err;
          762         }
          763         if(!msgRecv(c, &m)) {
          764                 fprint(2, "tlsClient nepm=%d\n", nepm);
          765                 tlsError(c, EInternalError, "can't read server Finished: %r");
          766                 goto Err;
          767         }
          768         if(m.tag != HFinished) {
          769                 fprint(2, "tlsClient nepm=%d\n", nepm);
          770                 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
          771                 goto Err;
          772         }
          773 
          774         if(!finishedMatch(c, &m.u.finished)) {
          775                 tlsError(c, EHandshakeFailure, "finished verification failed");
          776                 goto Err;
          777         }
          778         msgClear(&m);
          779 
          780         if(fprint(c->ctl, "opened") < 0){
          781                 if(trace)
          782                         trace("unable to do final open: %r\n");
          783                 goto Err;
          784         }
          785         tlsSecOk(c->sec);
          786         return c;
          787 
          788 Err:
          789         free(epm);
          790         msgClear(&m);
          791         tlsConnectionFree(c);
          792         return 0;
          793 }
          794 
          795 
          796 /*================= message functions ======================== */
          797 
          798 static uchar sendbuf[9000], *sendp;
          799 
          800 static int
          801 msgSend(TlsConnection *c, Msg *m, int act)
          802 {
          803         uchar *p; /* sendp = start of new message;  p = write pointer */
          804         int nn, n, i;
          805 
          806         if(sendp == nil)
          807                 sendp = sendbuf;
          808         p = sendp;
          809         if(c->trace)
          810                 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
          811 
          812         p[0] = m->tag;        /* header - fill in size later */
          813         p += 4;
          814 
          815         switch(m->tag) {
          816         default:
          817                 tlsError(c, EInternalError, "can't encode a %d", m->tag);
          818                 goto Err;
          819         case HClientHello:
          820                 /* version */
          821                 put16(p, m->u.clientHello.version);
          822                 p += 2;
          823 
          824                 /* random */
          825                 memmove(p, m->u.clientHello.random, RandomSize);
          826                 p += RandomSize;
          827 
          828                 /* sid */
          829                 n = m->u.clientHello.sid->len;
          830                 assert(n < 256);
          831                 p[0] = n;
          832                 memmove(p+1, m->u.clientHello.sid->data, n);
          833                 p += n+1;
          834 
          835                 n = m->u.clientHello.ciphers->len;
          836                 assert(n > 0 && n < 200);
          837                 put16(p, n*2);
          838                 p += 2;
          839                 for(i=0; i<n; i++) {
          840                         put16(p, m->u.clientHello.ciphers->data[i]);
          841                         p += 2;
          842                 }
          843 
          844                 n = m->u.clientHello.compressors->len;
          845                 assert(n > 0);
          846                 p[0] = n;
          847                 memmove(p+1, m->u.clientHello.compressors->data, n);
          848                 p += n+1;
          849                 break;
          850         case HServerHello:
          851                 put16(p, m->u.serverHello.version);
          852                 p += 2;
          853 
          854                 /* random */
          855                 memmove(p, m->u.serverHello.random, RandomSize);
          856                 p += RandomSize;
          857 
          858                 /* sid */
          859                 n = m->u.serverHello.sid->len;
          860                 assert(n < 256);
          861                 p[0] = n;
          862                 memmove(p+1, m->u.serverHello.sid->data, n);
          863                 p += n+1;
          864 
          865                 put16(p, m->u.serverHello.cipher);
          866                 p += 2;
          867                 p[0] = m->u.serverHello.compressor;
          868                 p += 1;
          869                 break;
          870         case HServerHelloDone:
          871                 break;
          872         case HCertificate:
          873                 nn = 0;
          874                 for(i = 0; i < m->u.certificate.ncert; i++)
          875                         nn += 3 + m->u.certificate.certs[i]->len;
          876                 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
          877                         tlsError(c, EInternalError, "output buffer too small for certificate");
          878                         goto Err;
          879                 }
          880                 put24(p, nn);
          881                 p += 3;
          882                 for(i = 0; i < m->u.certificate.ncert; i++){
          883                         put24(p, m->u.certificate.certs[i]->len);
          884                         p += 3;
          885                         memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
          886                         p += m->u.certificate.certs[i]->len;
          887                 }
          888                 break;
          889         case HClientKeyExchange:
          890                 n = m->u.clientKeyExchange.key->len;
          891                 if(c->version != SSL3Version){
          892                         put16(p, n);
          893                         p += 2;
          894                 }
          895                 memmove(p, m->u.clientKeyExchange.key->data, n);
          896                 p += n;
          897                 break;
          898         case HFinished:
          899                 memmove(p, m->u.finished.verify, m->u.finished.n);
          900                 p += m->u.finished.n;
          901                 break;
          902         }
          903 
          904         /* go back and fill in size */
          905         n = p-sendp;
          906         assert(p <= sendbuf+sizeof(sendbuf));
          907         put24(sendp+1, n-4);
          908 
          909         /* remember hash of Handshake messages */
          910         if(m->tag != HHelloRequest) {
          911                 md5(sendp, n, 0, &c->hsmd5);
          912                 sha1(sendp, n, 0, &c->hssha1);
          913         }
          914 
          915         sendp = p;
          916         if(act == AFlush){
          917                 sendp = sendbuf;
          918                 if(write(c->hand, sendbuf, p-sendbuf) < 0){
          919                         fprint(2, "write error: %r\n");
          920                         goto Err;
          921                 }
          922         }
          923         msgClear(m);
          924         return 1;
          925 Err:
          926         msgClear(m);
          927         return 0;
          928 }
          929 
          930 static uchar*
          931 tlsReadN(TlsConnection *c, int n)
          932 {
          933         uchar *p;
          934         int nn, nr;
          935 
          936         nn = c->ep - c->rp;
          937         if(nn < n){
          938                 if(c->rp != c->buf){
          939                         memmove(c->buf, c->rp, nn);
          940                         c->rp = c->buf;
          941                         c->ep = &c->buf[nn];
          942                 }
          943                 for(; nn < n; nn += nr) {
          944                         nr = read(c->hand, &c->rp[nn], n - nn);
          945                         if(nr <= 0)
          946                                 return nil;
          947                         c->ep += nr;
          948                 }
          949         }
          950         p = c->rp;
          951         c->rp += n;
          952         return p;
          953 }
          954 
          955 static int
          956 msgRecv(TlsConnection *c, Msg *m)
          957 {
          958         uchar *p;
          959         int type, n, nn, i, nsid, nrandom, nciph;
          960 
          961         for(;;) {
          962                 p = tlsReadN(c, 4);
          963                 if(p == nil)
          964                         return 0;
          965                 type = p[0];
          966                 n = get24(p+1);
          967 
          968                 if(type != HHelloRequest)
          969                         break;
          970                 if(n != 0) {
          971                         tlsError(c, EDecodeError, "invalid hello request during handshake");
          972                         return 0;
          973                 }
          974         }
          975 
          976         if(n > sizeof(c->buf)) {
          977                 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
          978                 return 0;
          979         }
          980 
          981         if(type == HSSL2ClientHello){
          982                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
          983                         This is sent by some clients that we must interoperate
          984                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
          985                 p = tlsReadN(c, n);
          986                 if(p == nil)
          987                         return 0;
          988                 md5(p, n, 0, &c->hsmd5);
          989                 sha1(p, n, 0, &c->hssha1);
          990                 m->tag = HClientHello;
          991                 if(n < 22)
          992                         goto Short;
          993                 m->u.clientHello.version = get16(p+1);
          994                 p += 3;
          995                 n -= 3;
          996                 nn = get16(p); /* cipher_spec_len */
          997                 nsid = get16(p + 2);
          998                 nrandom = get16(p + 4);
          999                 p += 6;
         1000                 n -= 6;
         1001                 if(nsid != 0         /* no sid's, since shouldn't restart using ssl2 header */
         1002                                 || nrandom < 16 || nn % 3)
         1003                         goto Err;
         1004                 if(c->trace && (n - nrandom != nn))
         1005                         c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
         1006                 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
         1007                 nciph = 0;
         1008                 for(i = 0; i < nn; i += 3)
         1009                         if(p[i] == 0)
         1010                                 nciph++;
         1011                 m->u.clientHello.ciphers = newints(nciph);
         1012                 nciph = 0;
         1013                 for(i = 0; i < nn; i += 3)
         1014                         if(p[i] == 0)
         1015                                 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
         1016                 p += nn;
         1017                 m->u.clientHello.sid = makebytes(nil, 0);
         1018                 if(nrandom > RandomSize)
         1019                         nrandom = RandomSize;
         1020                 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
         1021                 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
         1022                 m->u.clientHello.compressors = newbytes(1);
         1023                 m->u.clientHello.compressors->data[0] = CompressionNull;
         1024                 goto Ok;
         1025         }
         1026 
         1027         md5(p, 4, 0, &c->hsmd5);
         1028         sha1(p, 4, 0, &c->hssha1);
         1029 
         1030         p = tlsReadN(c, n);
         1031         if(p == nil)
         1032                 return 0;
         1033 
         1034         md5(p, n, 0, &c->hsmd5);
         1035         sha1(p, n, 0, &c->hssha1);
         1036 
         1037         m->tag = type;
         1038 
         1039         switch(type) {
         1040         default:
         1041                 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
         1042                 goto Err;
         1043         case HClientHello:
         1044                 if(n < 2)
         1045                         goto Short;
         1046                 m->u.clientHello.version = get16(p);
         1047                 p += 2;
         1048                 n -= 2;
         1049 
         1050                 if(n < RandomSize)
         1051                         goto Short;
         1052                 memmove(m->u.clientHello.random, p, RandomSize);
         1053                 p += RandomSize;
         1054                 n -= RandomSize;
         1055                 if(n < 1 || n < p[0]+1)
         1056                         goto Short;
         1057                 m->u.clientHello.sid = makebytes(p+1, p[0]);
         1058                 p += m->u.clientHello.sid->len+1;
         1059                 n -= m->u.clientHello.sid->len+1;
         1060 
         1061                 if(n < 2)
         1062                         goto Short;
         1063                 nn = get16(p);
         1064                 p += 2;
         1065                 n -= 2;
         1066 
         1067                 if((nn & 1) || n < nn || nn < 2)
         1068                         goto Short;
         1069                 m->u.clientHello.ciphers = newints(nn >> 1);
         1070                 for(i = 0; i < nn; i += 2)
         1071                         m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
         1072                 p += nn;
         1073                 n -= nn;
         1074 
         1075                 if(n < 1 || n < p[0]+1 || p[0] == 0)
         1076                         goto Short;
         1077                 nn = p[0];
         1078                 m->u.clientHello.compressors = newbytes(nn);
         1079                 memmove(m->u.clientHello.compressors->data, p+1, nn);
         1080                 n -= nn + 1;
         1081                 break;
         1082         case HServerHello:
         1083                 if(n < 2)
         1084                         goto Short;
         1085                 m->u.serverHello.version = get16(p);
         1086                 p += 2;
         1087                 n -= 2;
         1088 
         1089                 if(n < RandomSize)
         1090                         goto Short;
         1091                 memmove(m->u.serverHello.random, p, RandomSize);
         1092                 p += RandomSize;
         1093                 n -= RandomSize;
         1094 
         1095                 if(n < 1 || n < p[0]+1)
         1096                         goto Short;
         1097                 m->u.serverHello.sid = makebytes(p+1, p[0]);
         1098                 p += m->u.serverHello.sid->len+1;
         1099                 n -= m->u.serverHello.sid->len+1;
         1100 
         1101                 if(n < 3)
         1102                         goto Short;
         1103                 m->u.serverHello.cipher = get16(p);
         1104                 m->u.serverHello.compressor = p[2];
         1105                 n -= 3;
         1106                 break;
         1107         case HCertificate:
         1108                 if(n < 3)
         1109                         goto Short;
         1110                 nn = get24(p);
         1111                 p += 3;
         1112                 n -= 3;
         1113                 if(n != nn)
         1114                         goto Short;
         1115                 /* certs */
         1116                 i = 0;
         1117                 while(n > 0) {
         1118                         if(n < 3)
         1119                                 goto Short;
         1120                         nn = get24(p);
         1121                         p += 3;
         1122                         n -= 3;
         1123                         if(nn > n)
         1124                                 goto Short;
         1125                         m->u.certificate.ncert = i+1;
         1126                         m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
         1127                         m->u.certificate.certs[i] = makebytes(p, nn);
         1128                         p += nn;
         1129                         n -= nn;
         1130                         i++;
         1131                 }
         1132                 break;
         1133         case HCertificateRequest:
         1134                 if(n < 2)
         1135                         goto Short;
         1136                 nn = get16(p);
         1137                 p += 2;
         1138                 n -= 2;
         1139                 if(nn < 1 || nn > n)
         1140                         goto Short;
         1141                 m->u.certificateRequest.types = makebytes(p, nn);
         1142                 nn = get24(p);
         1143                 p += 3;
         1144                 n -= 3;
         1145                 if(nn == 0 || n != nn)
         1146                         goto Short;
         1147                 /* cas */
         1148                 i = 0;
         1149                 while(n > 0) {
         1150                         if(n < 2)
         1151                                 goto Short;
         1152                         nn = get16(p);
         1153                         p += 2;
         1154                         n -= 2;
         1155                         if(nn < 1 || nn > n)
         1156                                 goto Short;
         1157                         m->u.certificateRequest.nca = i+1;
         1158                         m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
         1159                         m->u.certificateRequest.cas[i] = makebytes(p, nn);
         1160                         p += nn;
         1161                         n -= nn;
         1162                         i++;
         1163                 }
         1164                 break;
         1165         case HServerHelloDone:
         1166                 break;
         1167         case HClientKeyExchange:
         1168                 /*
         1169                  * this message depends upon the encryption selected
         1170                  * assume rsa.
         1171                  */
         1172                 if(c->version == SSL3Version)
         1173                         nn = n;
         1174                 else{
         1175                         if(n < 2)
         1176                                 goto Short;
         1177                         nn = get16(p);
         1178                         p += 2;
         1179                         n -= 2;
         1180                 }
         1181                 if(n < nn)
         1182                         goto Short;
         1183                 m->u.clientKeyExchange.key = makebytes(p, nn);
         1184                 n -= nn;
         1185                 break;
         1186         case HFinished:
         1187                 m->u.finished.n = c->finished.n;
         1188                 if(n < m->u.finished.n)
         1189                         goto Short;
         1190                 memmove(m->u.finished.verify, p, m->u.finished.n);
         1191                 n -= m->u.finished.n;
         1192                 break;
         1193         }
         1194 
         1195         if(type != HClientHello && n != 0)
         1196                 goto Short;
         1197 Ok:
         1198         if(c->trace){
         1199                 char buf[8000];
         1200                 c->trace("recv %s", msgPrint(buf, sizeof buf, m));
         1201         }
         1202         return 1;
         1203 Short:
         1204         tlsError(c, EDecodeError, "handshake message has invalid length");
         1205 Err:
         1206         msgClear(m);
         1207         return 0;
         1208 }
         1209 
         1210 static void
         1211 msgClear(Msg *m)
         1212 {
         1213         int i;
         1214 
         1215         switch(m->tag) {
         1216         default:
         1217                 sysfatal("msgClear: unknown message type: %d\n", m->tag);
         1218         case HHelloRequest:
         1219                 break;
         1220         case HClientHello:
         1221                 freebytes(m->u.clientHello.sid);
         1222                 freeints(m->u.clientHello.ciphers);
         1223                 freebytes(m->u.clientHello.compressors);
         1224                 break;
         1225         case HServerHello:
         1226                 freebytes(m->u.clientHello.sid);
         1227                 break;
         1228         case HCertificate:
         1229                 for(i=0; i<m->u.certificate.ncert; i++)
         1230                         freebytes(m->u.certificate.certs[i]);
         1231                 free(m->u.certificate.certs);
         1232                 break;
         1233         case HCertificateRequest:
         1234                 freebytes(m->u.certificateRequest.types);
         1235                 for(i=0; i<m->u.certificateRequest.nca; i++)
         1236                         freebytes(m->u.certificateRequest.cas[i]);
         1237                 free(m->u.certificateRequest.cas);
         1238                 break;
         1239         case HServerHelloDone:
         1240                 break;
         1241         case HClientKeyExchange:
         1242                 freebytes(m->u.clientKeyExchange.key);
         1243                 break;
         1244         case HFinished:
         1245                 break;
         1246         }
         1247         memset(m, 0, sizeof(Msg));
         1248 }
         1249 
         1250 static char *
         1251 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
         1252 {
         1253         int i;
         1254 
         1255         if(s0)
         1256                 bs = seprint(bs, be, "%s", s0);
         1257         bs = seprint(bs, be, "[");
         1258         if(b == nil)
         1259                 bs = seprint(bs, be, "nil");
         1260         else
         1261                 for(i=0; i<b->len; i++)
         1262                         bs = seprint(bs, be, "%.2x ", b->data[i]);
         1263         bs = seprint(bs, be, "]");
         1264         if(s1)
         1265                 bs = seprint(bs, be, "%s", s1);
         1266         return bs;
         1267 }
         1268 
         1269 static char *
         1270 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
         1271 {
         1272         int i;
         1273 
         1274         if(s0)
         1275                 bs = seprint(bs, be, "%s", s0);
         1276         bs = seprint(bs, be, "[");
         1277         if(b == nil)
         1278                 bs = seprint(bs, be, "nil");
         1279         else
         1280                 for(i=0; i<b->len; i++)
         1281                         bs = seprint(bs, be, "%x ", b->data[i]);
         1282         bs = seprint(bs, be, "]");
         1283         if(s1)
         1284                 bs = seprint(bs, be, "%s", s1);
         1285         return bs;
         1286 }
         1287 
         1288 static char*
         1289 msgPrint(char *buf, int n, Msg *m)
         1290 {
         1291         int i;
         1292         char *bs = buf, *be = buf+n;
         1293 
         1294         switch(m->tag) {
         1295         default:
         1296                 bs = seprint(bs, be, "unknown %d\n", m->tag);
         1297                 break;
         1298         case HClientHello:
         1299                 bs = seprint(bs, be, "ClientHello\n");
         1300                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
         1301                 bs = seprint(bs, be, "\trandom: ");
         1302                 for(i=0; i<RandomSize; i++)
         1303                         bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
         1304                 bs = seprint(bs, be, "\n");
         1305                 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
         1306                 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
         1307                 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
         1308                 break;
         1309         case HServerHello:
         1310                 bs = seprint(bs, be, "ServerHello\n");
         1311                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
         1312                 bs = seprint(bs, be, "\trandom: ");
         1313                 for(i=0; i<RandomSize; i++)
         1314                         bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
         1315                 bs = seprint(bs, be, "\n");
         1316                 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
         1317                 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
         1318                 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
         1319                 break;
         1320         case HCertificate:
         1321                 bs = seprint(bs, be, "Certificate\n");
         1322                 for(i=0; i<m->u.certificate.ncert; i++)
         1323                         bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
         1324                 break;
         1325         case HCertificateRequest:
         1326                 bs = seprint(bs, be, "CertificateRequest\n");
         1327                 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
         1328                 bs = seprint(bs, be, "\tcertificateauthorities\n");
         1329                 for(i=0; i<m->u.certificateRequest.nca; i++)
         1330                         bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
         1331                 break;
         1332         case HServerHelloDone:
         1333                 bs = seprint(bs, be, "ServerHelloDone\n");
         1334                 break;
         1335         case HClientKeyExchange:
         1336                 bs = seprint(bs, be, "HClientKeyExchange\n");
         1337                 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
         1338                 break;
         1339         case HFinished:
         1340                 bs = seprint(bs, be, "HFinished\n");
         1341                 for(i=0; i<m->u.finished.n; i++)
         1342                         bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
         1343                 bs = seprint(bs, be, "\n");
         1344                 break;
         1345         }
         1346         USED(bs);
         1347         return buf;
         1348 }
         1349 
         1350 static void
         1351 tlsError(TlsConnection *c, int err, char *fmt, ...)
         1352 {
         1353         char msg[512];
         1354         va_list arg;
         1355 
         1356         va_start(arg, fmt);
         1357         vseprint(msg, msg+sizeof(msg), fmt, arg);
         1358         va_end(arg);
         1359         if(c->trace)
         1360                 c->trace("tlsError: %s\n", msg);
         1361         else if(c->erred)
         1362                 fprint(2, "double error: %r, %s", msg);
         1363         else
         1364                 werrstr("tls: local %s", msg);
         1365         c->erred = 1;
         1366         fprint(c->ctl, "alert %d", err);
         1367 }
         1368 
         1369 /* commit to specific version number */
         1370 static int
         1371 setVersion(TlsConnection *c, int version)
         1372 {
         1373         if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
         1374                 return -1;
         1375         if(version > c->version)
         1376                 version = c->version;
         1377         if(version == SSL3Version) {
         1378                 c->version = version;
         1379                 c->finished.n = SSL3FinishedLen;
         1380         }else if(version == TLSVersion){
         1381                 c->version = version;
         1382                 c->finished.n = TLSFinishedLen;
         1383         }else
         1384                 return -1;
         1385         c->verset = 1;
         1386         return fprint(c->ctl, "version 0x%x", version);
         1387 }
         1388 
         1389 /* confirm that received Finished message matches the expected value */
         1390 static int
         1391 finishedMatch(TlsConnection *c, Finished *f)
         1392 {
         1393         return memcmp(f->verify, c->finished.verify, f->n) == 0;
         1394 }
         1395 
         1396 /* free memory associated with TlsConnection struct */
         1397 /*                (but don't close the TLS channel itself) */
         1398 static void
         1399 tlsConnectionFree(TlsConnection *c)
         1400 {
         1401         tlsSecClose(c->sec);
         1402         freebytes(c->sid);
         1403         freebytes(c->cert);
         1404         memset(c, 0, sizeof(*c));
         1405         free(c);
         1406 }
         1407 
         1408 
         1409 /*================= cipher choices ======================== */
         1410 
         1411 static int weakCipher[CipherMax] =
         1412 {
         1413         1,        /* TLS_NULL_WITH_NULL_NULL */
         1414         1,        /* TLS_RSA_WITH_NULL_MD5 */
         1415         1,        /* TLS_RSA_WITH_NULL_SHA */
         1416         1,        /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
         1417         0,        /* TLS_RSA_WITH_RC4_128_MD5 */
         1418         0,        /* TLS_RSA_WITH_RC4_128_SHA */
         1419         1,        /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
         1420         0,        /* TLS_RSA_WITH_IDEA_CBC_SHA */
         1421         1,        /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
         1422         0,        /* TLS_RSA_WITH_DES_CBC_SHA */
         1423         0,        /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
         1424         1,        /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
         1425         0,        /* TLS_DH_DSS_WITH_DES_CBC_SHA */
         1426         0,        /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
         1427         1,        /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
         1428         0,        /* TLS_DH_RSA_WITH_DES_CBC_SHA */
         1429         0,        /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
         1430         1,        /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
         1431         0,        /* TLS_DHE_DSS_WITH_DES_CBC_SHA */
         1432         0,        /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
         1433         1,        /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
         1434         0,        /* TLS_DHE_RSA_WITH_DES_CBC_SHA */
         1435         0,        /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
         1436         1,        /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
         1437         1,        /* TLS_DH_anon_WITH_RC4_128_MD5 */
         1438         1,        /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
         1439         1,        /* TLS_DH_anon_WITH_DES_CBC_SHA */
         1440         1,        /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
         1441 };
         1442 
         1443 static int
         1444 setAlgs(TlsConnection *c, int a)
         1445 {
         1446         int i;
         1447 
         1448         for(i = 0; i < nelem(cipherAlgs); i++){
         1449                 if(cipherAlgs[i].tlsid == a){
         1450                         c->enc = cipherAlgs[i].enc;
         1451                         c->digest = cipherAlgs[i].digest;
         1452                         c->nsecret = cipherAlgs[i].nsecret;
         1453                         if(c->nsecret > MaxKeyData)
         1454                                 return 0;
         1455                         return 1;
         1456                 }
         1457         }
         1458         return 0;
         1459 }
         1460 
         1461 static int
         1462 okCipher(Ints *cv)
         1463 {
         1464         int weak, i, j, c;
         1465 
         1466         weak = 1;
         1467         for(i = 0; i < cv->len; i++) {
         1468                 c = cv->data[i];
         1469                 if(c >= CipherMax)
         1470                         weak = 0;
         1471                 else
         1472                         weak &= weakCipher[c];
         1473                 for(j = 0; j < nelem(cipherAlgs); j++)
         1474                         if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
         1475                                 return c;
         1476         }
         1477         if(weak)
         1478                 return -2;
         1479         return -1;
         1480 }
         1481 
         1482 static int
         1483 okCompression(Bytes *cv)
         1484 {
         1485         int i, j, c;
         1486 
         1487         for(i = 0; i < cv->len; i++) {
         1488                 c = cv->data[i];
         1489                 for(j = 0; j < nelem(compressors); j++) {
         1490                         if(compressors[j] == c)
         1491                                 return c;
         1492                 }
         1493         }
         1494         return -1;
         1495 }
         1496 
         1497 static Lock        ciphLock;
         1498 static int        nciphers;
         1499 
         1500 static int
         1501 initCiphers(void)
         1502 {
         1503         enum {MaxAlgF = 1024, MaxAlgs = 10};
         1504         char s[MaxAlgF], *flds[MaxAlgs];
         1505         int i, j, n, ok;
         1506 
         1507         lock(&ciphLock);
         1508         if(nciphers){
         1509                 unlock(&ciphLock);
         1510                 return nciphers;
         1511         }
         1512         j = open("#a/tls/encalgs", OREAD);
         1513         if(j < 0){
         1514                 werrstr("can't open #a/tls/encalgs: %r");
         1515                 return 0;
         1516         }
         1517         n = read(j, s, MaxAlgF-1);
         1518         close(j);
         1519         if(n <= 0){
         1520                 werrstr("nothing in #a/tls/encalgs: %r");
         1521                 return 0;
         1522         }
         1523         s[n] = 0;
         1524         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
         1525         for(i = 0; i < nelem(cipherAlgs); i++){
         1526                 ok = 0;
         1527                 for(j = 0; j < n; j++){
         1528                         if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
         1529                                 ok = 1;
         1530                                 break;
         1531                         }
         1532                 }
         1533                 cipherAlgs[i].ok = ok;
         1534         }
         1535 
         1536         j = open("#a/tls/hashalgs", OREAD);
         1537         if(j < 0){
         1538                 werrstr("can't open #a/tls/hashalgs: %r");
         1539                 return 0;
         1540         }
         1541         n = read(j, s, MaxAlgF-1);
         1542         close(j);
         1543         if(n <= 0){
         1544                 werrstr("nothing in #a/tls/hashalgs: %r");
         1545                 return 0;
         1546         }
         1547         s[n] = 0;
         1548         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
         1549         for(i = 0; i < nelem(cipherAlgs); i++){
         1550                 ok = 0;
         1551                 for(j = 0; j < n; j++){
         1552                         if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
         1553                                 ok = 1;
         1554                                 break;
         1555                         }
         1556                 }
         1557                 cipherAlgs[i].ok &= ok;
         1558                 if(cipherAlgs[i].ok)
         1559                         nciphers++;
         1560         }
         1561         unlock(&ciphLock);
         1562         return nciphers;
         1563 }
         1564 
         1565 static Ints*
         1566 makeciphers(void)
         1567 {
         1568         Ints *is;
         1569         int i, j;
         1570 
         1571         is = newints(nciphers);
         1572         j = 0;
         1573         for(i = 0; i < nelem(cipherAlgs); i++){
         1574                 if(cipherAlgs[i].ok)
         1575                         is->data[j++] = cipherAlgs[i].tlsid;
         1576         }
         1577         return is;
         1578 }
         1579 
         1580 
         1581 
         1582 /*================= security functions ======================== */
         1583 
         1584 /* given X.509 certificate, set up connection to factotum */
         1585 /*        for using corresponding private key */
         1586 static AuthRpc*
         1587 factotum_rsa_open(uchar *cert, int certlen)
         1588 {
         1589         char *s;
         1590         mpint *pub = nil;
         1591         RSApub *rsapub;
         1592         AuthRpc *rpc;
         1593 
         1594         if((rpc = auth_allocrpc()) == nil){
         1595                 return nil;
         1596         }
         1597         s = "proto=rsa service=tls role=client";
         1598         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
         1599                 factotum_rsa_close(rpc);
         1600                 return nil;
         1601         }
         1602 
         1603         /* roll factotum keyring around to match certificate */
         1604         rsapub = X509toRSApub(cert, certlen, nil, 0);
         1605         while(1){
         1606                 if(auth_rpc(rpc, "read", nil, 0) != ARok){
         1607                         factotum_rsa_close(rpc);
         1608                         rpc = nil;
         1609                         goto done;
         1610                 }
         1611                 pub = strtomp(rpc->arg, nil, 16, nil);
         1612                 assert(pub != nil);
         1613                 if(mpcmp(pub,rsapub->n) == 0)
         1614                         break;
         1615         }
         1616 done:
         1617         mpfree(pub);
         1618         rsapubfree(rsapub);
         1619         return rpc;
         1620 }
         1621 
         1622 static mpint*
         1623 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
         1624 {
         1625         char *p;
         1626         int rv;
         1627 
         1628         if((p = mptoa(cipher, 16, nil, 0)) == nil)
         1629                 return nil;
         1630         rv = auth_rpc(rpc, "write", p, strlen(p));
         1631         free(p);
         1632         if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
         1633                 return nil;
         1634         mpfree(cipher);
         1635         return strtomp(rpc->arg, nil, 16, nil);
         1636 }
         1637 
         1638 static void
         1639 factotum_rsa_close(AuthRpc*rpc)
         1640 {
         1641         if(!rpc)
         1642                 return;
         1643         close(rpc->afd);
         1644         auth_freerpc(rpc);
         1645 }
         1646 
         1647 static void
         1648 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
         1649 {
         1650         uchar ai[MD5dlen], tmp[MD5dlen];
         1651         int i, n;
         1652         MD5state *s;
         1653 
         1654         /* generate a1 */
         1655         s = hmac_md5(label, nlabel, key, nkey, nil, nil);
         1656         s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
         1657         hmac_md5(seed1, nseed1, key, nkey, ai, s);
         1658 
         1659         while(nbuf > 0) {
         1660                 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
         1661                 s = hmac_md5(label, nlabel, key, nkey, nil, s);
         1662                 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
         1663                 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
         1664                 n = MD5dlen;
         1665                 if(n > nbuf)
         1666                         n = nbuf;
         1667                 for(i = 0; i < n; i++)
         1668                         buf[i] ^= tmp[i];
         1669                 buf += n;
         1670                 nbuf -= n;
         1671                 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
         1672                 memmove(ai, tmp, MD5dlen);
         1673         }
         1674 }
         1675 
         1676 static void
         1677 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
         1678 {
         1679         uchar ai[SHA1dlen], tmp[SHA1dlen];
         1680         int i, n;
         1681         SHAstate *s;
         1682 
         1683         /* generate a1 */
         1684         s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
         1685         s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
         1686         hmac_sha1(seed1, nseed1, key, nkey, ai, s);
         1687 
         1688         while(nbuf > 0) {
         1689                 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
         1690                 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
         1691                 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
         1692                 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
         1693                 n = SHA1dlen;
         1694                 if(n > nbuf)
         1695                         n = nbuf;
         1696                 for(i = 0; i < n; i++)
         1697                         buf[i] ^= tmp[i];
         1698                 buf += n;
         1699                 nbuf -= n;
         1700                 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
         1701                 memmove(ai, tmp, SHA1dlen);
         1702         }
         1703 }
         1704 
         1705 /* fill buf with md5(args)^sha1(args) */
         1706 static void
         1707 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
         1708 {
         1709         int i;
         1710         int nlabel = strlen(label);
         1711         int n = (nkey + 1) >> 1;
         1712 
         1713         for(i = 0; i < nbuf; i++)
         1714                 buf[i] = 0;
         1715         tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
         1716         tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
         1717 }
         1718 
         1719 /*
         1720  * for setting server session id's
         1721  */
         1722 static Lock        sidLock;
         1723 static long        maxSid = 1;
         1724 
         1725 /* the keys are verified to have the same public components
         1726  * and to function correctly with pkcs 1 encryption and decryption. */
         1727 static TlsSec*
         1728 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
         1729 {
         1730         TlsSec *sec = emalloc(sizeof(*sec));
         1731 
         1732         USED(csid); USED(ncsid);  /* ignore csid for now */
         1733 
         1734         memmove(sec->crandom, crandom, RandomSize);
         1735         sec->clientVers = cvers;
         1736 
         1737         put32(sec->srandom, time(0));
         1738         genrandom(sec->srandom+4, RandomSize-4);
         1739         memmove(srandom, sec->srandom, RandomSize);
         1740 
         1741         /*
         1742          * make up a unique sid: use our pid, and and incrementing id
         1743          * can signal no sid by setting nssid to 0.
         1744          */
         1745         memset(ssid, 0, SidSize);
         1746         put32(ssid, getpid());
         1747         lock(&sidLock);
         1748         put32(ssid+4, maxSid++);
         1749         unlock(&sidLock);
         1750         *nssid = SidSize;
         1751         return sec;
         1752 }
         1753 
         1754 static int
         1755 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
         1756 {
         1757         if(epm != nil){
         1758                 if(setVers(sec, vers) < 0)
         1759                         goto Err;
         1760                 serverMasterSecret(sec, epm, nepm);
         1761         }else if(sec->vers != vers){
         1762                 werrstr("mismatched session versions");
         1763                 goto Err;
         1764         }
         1765         setSecrets(sec, kd, nkd);
         1766         return 0;
         1767 Err:
         1768         sec->ok = -1;
         1769         return -1;
         1770 }
         1771 
         1772 static TlsSec*
         1773 tlsSecInitc(int cvers, uchar *crandom)
         1774 {
         1775         TlsSec *sec = emalloc(sizeof(*sec));
         1776         sec->clientVers = cvers;
         1777         put32(sec->crandom, time(0));
         1778         genrandom(sec->crandom+4, RandomSize-4);
         1779         memmove(crandom, sec->crandom, RandomSize);
         1780         return sec;
         1781 }
         1782 
         1783 static int
         1784 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
         1785 {
         1786         RSApub *pub;
         1787 
         1788         pub = nil;
         1789 
         1790         USED(sid);
         1791         USED(nsid);
         1792 
         1793         memmove(sec->srandom, srandom, RandomSize);
         1794 
         1795         if(setVers(sec, vers) < 0)
         1796                 goto Err;
         1797 
         1798         pub = X509toRSApub(cert, ncert, nil, 0);
         1799         if(pub == nil){
         1800                 werrstr("invalid x509/rsa certificate");
         1801                 goto Err;
         1802         }
         1803         if(clientMasterSecret(sec, pub, epm, nepm) < 0)
         1804                 goto Err;
         1805         rsapubfree(pub);
         1806         setSecrets(sec, kd, nkd);
         1807         return 0;
         1808 
         1809 Err:
         1810         if(pub != nil)
         1811                 rsapubfree(pub);
         1812         sec->ok = -1;
         1813         return -1;
         1814 }
         1815 
         1816 static int
         1817 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
         1818 {
         1819         if(sec->nfin != nfin){
         1820                 sec->ok = -1;
         1821                 werrstr("invalid finished exchange");
         1822                 return -1;
         1823         }
         1824         md5.malloced = 0;
         1825         sha1.malloced = 0;
         1826         (*sec->setFinished)(sec, md5, sha1, fin, isclient);
         1827         return 1;
         1828 }
         1829 
         1830 static void
         1831 tlsSecOk(TlsSec *sec)
         1832 {
         1833         if(sec->ok == 0)
         1834                 sec->ok = 1;
         1835 }
         1836 
         1837 /*
         1838 static void
         1839 tlsSecKill(TlsSec *sec)
         1840 {
         1841         if(!sec)
         1842                 return;
         1843         factotum_rsa_close(sec->rpc);
         1844         sec->ok = -1;
         1845 }
         1846 */
         1847 
         1848 static void
         1849 tlsSecClose(TlsSec *sec)
         1850 {
         1851         if(!sec)
         1852                 return;
         1853         factotum_rsa_close(sec->rpc);
         1854         free(sec->server);
         1855         free(sec);
         1856 }
         1857 
         1858 static int
         1859 setVers(TlsSec *sec, int v)
         1860 {
         1861         if(v == SSL3Version){
         1862                 sec->setFinished = sslSetFinished;
         1863                 sec->nfin = SSL3FinishedLen;
         1864                 sec->prf = sslPRF;
         1865         }else if(v == TLSVersion){
         1866                 sec->setFinished = tlsSetFinished;
         1867                 sec->nfin = TLSFinishedLen;
         1868                 sec->prf = tlsPRF;
         1869         }else{
         1870                 werrstr("invalid version");
         1871                 return -1;
         1872         }
         1873         sec->vers = v;
         1874         return 0;
         1875 }
         1876 
         1877 /*
         1878  * generate secret keys from the master secret.
         1879  *
         1880  * different crypto selections will require different amounts
         1881  * of key expansion and use of key expansion data,
         1882  * but it's all generated using the same function.
         1883  */
         1884 static void
         1885 setSecrets(TlsSec *sec, uchar *kd, int nkd)
         1886 {
         1887         (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
         1888                         sec->srandom, RandomSize, sec->crandom, RandomSize);
         1889 }
         1890 
         1891 /*
         1892  * set the master secret from the pre-master secret.
         1893  */
         1894 static void
         1895 setMasterSecret(TlsSec *sec, Bytes *pm)
         1896 {
         1897         (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
         1898                         sec->crandom, RandomSize, sec->srandom, RandomSize);
         1899 }
         1900 
         1901 static void
         1902 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
         1903 {
         1904         Bytes *pm;
         1905 
         1906         pm = pkcs1_decrypt(sec, epm, nepm);
         1907 
         1908         /* if the client messed up, just continue as if everything is ok, */
         1909         /* to prevent attacks to check for correctly formatted messages. */
         1910         /* Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client. */
         1911         if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
         1912                 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
         1913                         sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
         1914                 sec->ok = -1;
         1915                 if(pm != nil)
         1916                         freebytes(pm);
         1917                 pm = newbytes(MasterSecretSize);
         1918                 genrandom(pm->data, MasterSecretSize);
         1919         }
         1920         setMasterSecret(sec, pm);
         1921         memset(pm->data, 0, pm->len);
         1922         freebytes(pm);
         1923 }
         1924 
         1925 static int
         1926 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
         1927 {
         1928         Bytes *pm, *key;
         1929 
         1930         pm = newbytes(MasterSecretSize);
         1931         put16(pm->data, sec->clientVers);
         1932         genrandom(pm->data+2, MasterSecretSize - 2);
         1933 
         1934         setMasterSecret(sec, pm);
         1935 
         1936         key = pkcs1_encrypt(pm, pub, 2);
         1937         memset(pm->data, 0, pm->len);
         1938         freebytes(pm);
         1939         if(key == nil){
         1940                 werrstr("tls pkcs1_encrypt failed");
         1941                 return -1;
         1942         }
         1943 
         1944         *nepm = key->len;
         1945         *epm = malloc(*nepm);
         1946         if(*epm == nil){
         1947                 freebytes(key);
         1948                 werrstr("out of memory");
         1949                 return -1;
         1950         }
         1951         memmove(*epm, key->data, *nepm);
         1952 
         1953         freebytes(key);
         1954 
         1955         return 1;
         1956 }
         1957 
         1958 static void
         1959 sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
         1960 {
         1961         DigestState *s;
         1962         uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
         1963         char *label;
         1964 
         1965         if(isClient)
         1966                 label = "CLNT";
         1967         else
         1968                 label = "SRVR";
         1969 
         1970         md5((uchar*)label, 4, nil, &hsmd5);
         1971         md5(sec->sec, MasterSecretSize, nil, &hsmd5);
         1972         memset(pad, 0x36, 48);
         1973         md5(pad, 48, nil, &hsmd5);
         1974         md5(nil, 0, h0, &hsmd5);
         1975         memset(pad, 0x5C, 48);
         1976         s = md5(sec->sec, MasterSecretSize, nil, nil);
         1977         s = md5(pad, 48, nil, s);
         1978         md5(h0, MD5dlen, finished, s);
         1979 
         1980         sha1((uchar*)label, 4, nil, &hssha1);
         1981         sha1(sec->sec, MasterSecretSize, nil, &hssha1);
         1982         memset(pad, 0x36, 40);
         1983         sha1(pad, 40, nil, &hssha1);
         1984         sha1(nil, 0, h1, &hssha1);
         1985         memset(pad, 0x5C, 40);
         1986         s = sha1(sec->sec, MasterSecretSize, nil, nil);
         1987         s = sha1(pad, 40, nil, s);
         1988         sha1(h1, SHA1dlen, finished + MD5dlen, s);
         1989 }
         1990 
         1991 /* fill "finished" arg with md5(args)^sha1(args) */
         1992 static void
         1993 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
         1994 {
         1995         uchar h0[MD5dlen], h1[SHA1dlen];
         1996         char *label;
         1997 
         1998         /* get current hash value, but allow further messages to be hashed in */
         1999         md5(nil, 0, h0, &hsmd5);
         2000         sha1(nil, 0, h1, &hssha1);
         2001 
         2002         if(isClient)
         2003                 label = "client finished";
         2004         else
         2005                 label = "server finished";
         2006         tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
         2007 }
         2008 
         2009 static void
         2010 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
         2011 {
         2012         DigestState *s;
         2013         uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
         2014         int i, n, len;
         2015 
         2016         USED(label);
         2017         len = 1;
         2018         while(nbuf > 0){
         2019                 if(len > 26)
         2020                         return;
         2021                 for(i = 0; i < len; i++)
         2022                         tmp[i] = 'A' - 1 + len;
         2023                 s = sha1(tmp, len, nil, nil);
         2024                 s = sha1(key, nkey, nil, s);
         2025                 s = sha1(seed0, nseed0, nil, s);
         2026                 sha1(seed1, nseed1, sha1dig, s);
         2027                 s = md5(key, nkey, nil, nil);
         2028                 md5(sha1dig, SHA1dlen, md5dig, s);
         2029                 n = MD5dlen;
         2030                 if(n > nbuf)
         2031                         n = nbuf;
         2032                 memmove(buf, md5dig, n);
         2033                 buf += n;
         2034                 nbuf -= n;
         2035                 len++;
         2036         }
         2037 }
         2038 
         2039 static mpint*
         2040 bytestomp(Bytes* bytes)
         2041 {
         2042         mpint* ans;
         2043 
         2044         ans = betomp(bytes->data, bytes->len, nil);
         2045         return ans;
         2046 }
         2047 
         2048 /*
         2049  * Convert mpint* to Bytes, putting high order byte first.
         2050  */
         2051 static Bytes*
         2052 mptobytes(mpint* big)
         2053 {
         2054         int n, m;
         2055         uchar *a;
         2056         Bytes* ans;
         2057 
         2058         n = (mpsignif(big)+7)/8;
         2059         m = mptobe(big, nil, n, &a);
         2060         ans = makebytes(a, m);
         2061         return ans;
         2062 }
         2063 
         2064 /* Do RSA computation on block according to key, and pad */
         2065 /* result on left with zeros to make it modlen long. */
         2066 static Bytes*
         2067 rsacomp(Bytes* block, RSApub* key, int modlen)
         2068 {
         2069         mpint *x, *y;
         2070         Bytes *a, *ybytes;
         2071         int ylen;
         2072 
         2073         x = bytestomp(block);
         2074         y = rsaencrypt(key, x, nil);
         2075         mpfree(x);
         2076         ybytes = mptobytes(y);
         2077         ylen = ybytes->len;
         2078 
         2079         if(ylen < modlen) {
         2080                 a = newbytes(modlen);
         2081                 memset(a->data, 0, modlen-ylen);
         2082                 memmove(a->data+modlen-ylen, ybytes->data, ylen);
         2083                 freebytes(ybytes);
         2084                 ybytes = a;
         2085         }
         2086         else if(ylen > modlen) {
         2087                 /* assume it has leading zeros (mod should make it so) */
         2088                 a = newbytes(modlen);
         2089                 memmove(a->data, ybytes->data, modlen);
         2090                 freebytes(ybytes);
         2091                 ybytes = a;
         2092         }
         2093         mpfree(y);
         2094         return ybytes;
         2095 }
         2096 
         2097 /* encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1 */
         2098 static Bytes*
         2099 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
         2100 {
         2101         Bytes *pad, *eb, *ans;
         2102         int i, dlen, padlen, modlen;
         2103 
         2104         modlen = (mpsignif(key->n)+7)/8;
         2105         dlen = data->len;
         2106         if(modlen < 12 || dlen > modlen - 11)
         2107                 return nil;
         2108         padlen = modlen - 3 - dlen;
         2109         pad = newbytes(padlen);
         2110         genrandom(pad->data, padlen);
         2111         for(i = 0; i < padlen; i++) {
         2112                 if(blocktype == 0)
         2113                         pad->data[i] = 0;
         2114                 else if(blocktype == 1)
         2115                         pad->data[i] = 255;
         2116                 else if(pad->data[i] == 0)
         2117                         pad->data[i] = 1;
         2118         }
         2119         eb = newbytes(modlen);
         2120         eb->data[0] = 0;
         2121         eb->data[1] = blocktype;
         2122         memmove(eb->data+2, pad->data, padlen);
         2123         eb->data[padlen+2] = 0;
         2124         memmove(eb->data+padlen+3, data->data, dlen);
         2125         ans = rsacomp(eb, key, modlen);
         2126         freebytes(eb);
         2127         freebytes(pad);
         2128         return ans;
         2129 }
         2130 
         2131 /* decrypt data according to PKCS#1, with given key. */
         2132 /* expect a block type of 2. */
         2133 static Bytes*
         2134 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
         2135 {
         2136         Bytes *eb, *ans = nil;
         2137         int i, modlen;
         2138         mpint *x, *y;
         2139 
         2140         modlen = (mpsignif(sec->rsapub->n)+7)/8;
         2141         if(nepm != modlen)
         2142                 return nil;
         2143         x = betomp(epm, nepm, nil);
         2144         y = factotum_rsa_decrypt(sec->rpc, x);
         2145         if(y == nil)
         2146                 return nil;
         2147         eb = mptobytes(y);
         2148         if(eb->len < modlen){ /* pad on left with zeros */
         2149                 ans = newbytes(modlen);
         2150                 memset(ans->data, 0, modlen-eb->len);
         2151                 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
         2152                 freebytes(eb);
         2153                 eb = ans;
         2154         }
         2155         if(eb->data[0] == 0 && eb->data[1] == 2) {
         2156                 for(i = 2; i < modlen; i++)
         2157                         if(eb->data[i] == 0)
         2158                                 break;
         2159                 if(i < modlen - 1)
         2160                         ans = makebytes(eb->data+i+1, modlen-(i+1));
         2161         }
         2162         freebytes(eb);
         2163         return ans;
         2164 }
         2165 
         2166 
         2167 /*================= general utility functions ======================== */
         2168 
         2169 static void *
         2170 emalloc(int n)
         2171 {
         2172         void *p;
         2173         if(n==0)
         2174                 n=1;
         2175         p = malloc(n);
         2176         if(p == nil){
         2177                 exits("out of memory");
         2178         }
         2179         memset(p, 0, n);
         2180         return p;
         2181 }
         2182 
         2183 static void *
         2184 erealloc(void *ReallocP, int ReallocN)
         2185 {
         2186         if(ReallocN == 0)
         2187                 ReallocN = 1;
         2188         if(!ReallocP)
         2189                 ReallocP = emalloc(ReallocN);
         2190         else if(!(ReallocP = realloc(ReallocP, ReallocN))){
         2191                 exits("out of memory");
         2192         }
         2193         return(ReallocP);
         2194 }
         2195 
         2196 static void
         2197 put32(uchar *p, u32int x)
         2198 {
         2199         p[0] = x>>24;
         2200         p[1] = x>>16;
         2201         p[2] = x>>8;
         2202         p[3] = x;
         2203 }
         2204 
         2205 static void
         2206 put24(uchar *p, int x)
         2207 {
         2208         p[0] = x>>16;
         2209         p[1] = x>>8;
         2210         p[2] = x;
         2211 }
         2212 
         2213 static void
         2214 put16(uchar *p, int x)
         2215 {
         2216         p[0] = x>>8;
         2217         p[1] = x;
         2218 }
         2219 
         2220 /*
         2221 static u32int
         2222 get32(uchar *p)
         2223 {
         2224         return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
         2225 }
         2226 */
         2227 
         2228 static int
         2229 get24(uchar *p)
         2230 {
         2231         return (p[0]<<16)|(p[1]<<8)|p[2];
         2232 }
         2233 
         2234 static int
         2235 get16(uchar *p)
         2236 {
         2237         return (p[0]<<8)|p[1];
         2238 }
         2239 
         2240 /* ANSI offsetof() */
         2241 #define OFFSET(x, s) ((intptr)(&(((s*)0)->x)))
         2242 
         2243 /*
         2244  * malloc and return a new Bytes structure capable of
         2245  * holding len bytes. (len >= 0)
         2246  * Used to use crypt_malloc, which aborts if malloc fails.
         2247  */
         2248 static Bytes*
         2249 newbytes(int len)
         2250 {
         2251         Bytes* ans;
         2252 
         2253         ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
         2254         ans->len = len;
         2255         return ans;
         2256 }
         2257 
         2258 /*
         2259  * newbytes(len), with data initialized from buf
         2260  */
         2261 static Bytes*
         2262 makebytes(uchar* buf, int len)
         2263 {
         2264         Bytes* ans;
         2265 
         2266         ans = newbytes(len);
         2267         memmove(ans->data, buf, len);
         2268         return ans;
         2269 }
         2270 
         2271 static void
         2272 freebytes(Bytes* b)
         2273 {
         2274         if(b != nil)
         2275                 free(b);
         2276 }
         2277 
         2278 /* len is number of ints */
         2279 static Ints*
         2280 newints(int len)
         2281 {
         2282         Ints* ans;
         2283 
         2284         ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
         2285         ans->len = len;
         2286         return ans;
         2287 }
         2288 
         2289 /*
         2290 static Ints*
         2291 makeints(int* buf, int len)
         2292 {
         2293         Ints* ans;
         2294 
         2295         ans = newints(len);
         2296         if(len > 0)
         2297                 memmove(ans->data, buf, len*sizeof(int));
         2298         return ans;
         2299 }
         2300 */
         2301 
         2302 static void
         2303 freeints(Ints* b)
         2304 {
         2305         if(b != nil)
         2306                 free(b);
         2307 }