diff --git a/cmd/request_cert.go b/cmd/request_cert.go index 5e954d8..dedcced 100644 --- a/cmd/request_cert.go +++ b/cmd/request_cert.go @@ -27,6 +27,10 @@ import ( "github.com/codegangsta/cli" "github.com/square/certstrap/depot" "github.com/square/certstrap/pkix" + + "strconv" + "encoding/asn1" + x509pkix "crypto/x509/pkix" ) // NewCertRequestCommand sets up a "request-cert" command to create a request for a new certificate (CSR) @@ -49,6 +53,8 @@ func NewCertRequestCommand() cli.Command { cli.StringFlag{"uri", "", "URI for subject alt name (comma separated)", ""}, cli.StringFlag{"key", "", "Path to private key PEM file. If blank, will generate new keypair.", ""}, cli.BoolFlag{"stdout", "Print signing request to stdout in addition to saving file", ""}, + + cli.StringFlag{"eku", "", "Comma-separated list of EKU OIDs. If the anyExtendedKeyUsage OID (2.5.29.37.0) is not in this list, the extension will be marked critical.", ""}, }, Action: newCertAction, } @@ -129,7 +135,58 @@ func newCertAction(c *cli.Context) { } } - csr, err := pkix.CreateCertificateSigningRequest(key, c.String("organizational-unit"), ips, domains, uris, c.String("organization"), c.String("country"), c.String("province"), c.String("locality"), name) + var ekuExtension *x509pkix.Extension + if c.IsSet("eku") { + var anyEKUOid asn1.ObjectIdentifier = asn1.ObjectIdentifier{2, 5, 29, 37, 0} + var sawAnyEKUOid bool = false + var oids []asn1.ObjectIdentifier + for _, oidString := range strings.Split(c.String("eku"), ",") { + var thisOid asn1.ObjectIdentifier + var isAnyEKUOid bool = true + for i, oidComponent := range strings.Split(oidString, ".") { + var thisOidComponent int + if val, err := strconv.Atoi(oidComponent); err == nil { + thisOidComponent = val + } else { + fmt.Fprintln(os.Stderr, "OID component parsing error:", err) + os.Exit(1) + } + thisOid = append(thisOid, thisOidComponent) + if i < len(anyEKUOid) { + isAnyEKUOid = isAnyEKUOid && (thisOidComponent == anyEKUOid[i]) + } else { + isAnyEKUOid = false + } + } + if len(thisOid) > 0 { + oids = append(oids, thisOid) + sawAnyEKUOid = sawAnyEKUOid || isAnyEKUOid + } + } + if len(oids) > 0 { + if val, err := asn1.Marshal(oids); err == nil { + ekuExtension = &x509pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 37}, + Critical: !sawAnyEKUOid, + Value: val, + } + } else { + fmt.Fprintln(os.Stderr, "Error marshalling EKU extension:", err) + os.Exit(1) + } + } + } + + var extensionsList []x509pkix.Extension + var extensions *[]x509pkix.Extension = nil + if ekuExtension != nil { + extensionsList = append(extensionsList, *ekuExtension) + } + if len(extensionsList) > 0 { + extensions = &extensionsList + } + + csr, err := pkix.CreateCertificateSigningRequest(key, c.String("organizational-unit"), ips, domains, uris, c.String("organization"), c.String("country"), c.String("province"), c.String("locality"), name, extensions) if err != nil { fmt.Fprintln(os.Stderr, "Create certificate request error:", err) os.Exit(1) diff --git a/pkix/cert_auth.go b/pkix/cert_auth.go index 68c8219..4b70e48 100644 --- a/pkix/cert_auth.go +++ b/pkix/cert_auth.go @@ -146,6 +146,8 @@ func CreateIntermediateCertificateAuthority(crtAuth *Certificate, keyAuth *Key, authTemplate.DNSNames = rawCsr.DNSNames authTemplate.URIs = rawCsr.URIs + authTemplate.ExtraExtensions = rawCsr.Extensions + rawCrtAuth, err := crtAuth.GetRawCertificate() if err != nil { return nil, err diff --git a/pkix/cert_host.go b/pkix/cert_host.go index dcbd71c..d9fa056 100644 --- a/pkix/cert_host.go +++ b/pkix/cert_host.go @@ -97,6 +97,8 @@ func CreateCertificateHost(crtAuth *Certificate, keyAuth *Key, csr *CertificateS hostTemplate.DNSNames = rawCsr.DNSNames hostTemplate.URIs = rawCsr.URIs + hostTemplate.ExtraExtensions = rawCsr.Extensions + rawCrtAuth, err := crtAuth.GetRawCertificate() if err != nil { return nil, err diff --git a/pkix/csr.go b/pkix/csr.go index 98a29fe..b3c6211 100644 --- a/pkix/csr.go +++ b/pkix/csr.go @@ -92,8 +92,7 @@ func ParseAndValidateURIs(uriList string) (res []*url.URL, err error) { } // CreateCertificateSigningRequest sets up a request to create a csr file with the given parameters -func CreateCertificateSigningRequest(key *Key, organizationalUnit string, ipList []net.IP, domainList []string, uriList []*url.URL, organization string, country string, province string, locality string, commonName string) (*CertificateSigningRequest, error) { - +func CreateCertificateSigningRequest(key *Key, organizationalUnit string, ipList []net.IP, domainList []string, uriList []*url.URL, organization string, country string, province string, locality string, commonName string, extensions *[]pkix.Extension) (*CertificateSigningRequest, error) { csrPkixName.CommonName = commonName if len(organizationalUnit) > 0 { @@ -117,6 +116,9 @@ func CreateCertificateSigningRequest(key *Key, organizationalUnit string, ipList DNSNames: domainList, URIs: uriList, } + if extensions != nil { + (*csrTemplate).ExtraExtensions = *extensions + } csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key.Private) if err != nil { diff --git a/pkix/csr_test.go b/pkix/csr_test.go index 29099f2..d6e61c2 100644 --- a/pkix/csr_test.go +++ b/pkix/csr_test.go @@ -79,7 +79,7 @@ func TestCreateCertificateSigningRequest(t *testing.T) { t.Fatal("Failed creating rsa key:", err) } - csr, err := CreateCertificateSigningRequest(key, csrHostname, nil, nil, nil, "example", "US", "California", "San Francisco", csrCN) + csr, err := CreateCertificateSigningRequest(key, csrHostname, nil, nil, nil, "example", "US", "California", "San Francisco", csrCN, nil) if err != nil { t.Fatal("Failed creating certificate request:", err) }