Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions simp_le.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,14 @@ class IOPlugin(object):
- for `chain`: certificate chain, a list of `OpenSSL.crypto.X509` instances
"""

@classmethod
def set_data(cls, account_key=None, key=None, cert=None, chain=None):
return cls.Data(account_key, key, cert, chain, challenge)

@classmethod
def set_data_bool(cls, account_key=False, key=False, cert=False, chain=False):
return cls.Data(account_key, key, cert, chain, challenge)

EMPTY_DATA = Data(account_key=None, key=None, cert=None, chain=None)

def __init__(self, path, **dummy_kwargs):
Expand Down Expand Up @@ -420,11 +428,10 @@ class AccountKey(FileIOPlugin, JWKIOPlugin):
WRITE_MODE = 'w'

def persisted(self):
return self.Data(account_key=True, key=False, cert=False, chain=False)
return self.set_data_bool(account_key=True)

def load_from_content(self, content):
return self.Data(account_key=self.load_jwk(content), key=None,
cert=None, chain=None)
return self.set_data(account_key=self.load_jwk(content))

def save(self, data):
return self.save_to_file(self.dump_jwk(data.account_key))
Expand Down Expand Up @@ -525,7 +532,7 @@ def get_output_or_fail(self, command):
def persisted(self):
"""Call the external script and see which data is persisted."""
output = self.get_output_or_fail('persisted').split()
return self.Data(
return self.set_data_bool(
account_key=(b'account_key' in output),
key=(b'key' in output),
cert=(b'cert' in output),
Expand All @@ -545,7 +552,7 @@ def load(self):
cert = self.load_cert(pems.pop(0)) if persisted.cert else None
chain = ([self.load_cert(cert_data) for cert_data in pems]
if persisted.chain else None)
return self.Data(account_key=account_key, key=key,
return self.set_data(account_key=account_key, key=key,
cert=cert, chain=chain)

def save(self, data):
Expand Down Expand Up @@ -684,12 +691,12 @@ class ChainFile(FileIOPlugin, OpenSSLIOPlugin):
"""Certificate chain plugin."""

def persisted(self):
return self.Data(account_key=False, key=False, cert=False, chain=True)
return self.set_data_bool(chain=True)

def load_from_content(self, output):
chain = [self.load_cert(cert_data)
for cert_data in split_pems(output)]
return self.Data(account_key=None, key=None, cert=None, chain=chain)
return self.set_data(chain=chain)

def save(self, data):
return self.save_to_file(_PEMS_SEP.join(
Expand All @@ -707,21 +714,21 @@ class FullChainFile(ChainFile):
"""Full chain file plugin."""

def persisted(self):
return self.Data(account_key=False, key=False, cert=True, chain=True)
return self.set_data_bool(cert=True, chain=True)

def load(self):
data = super(FullChainFile, self).load()
if data.chain is None:
cert, chain = None, None
else:
cert, chain = data.chain[0], data.chain[1:]
return self.Data(account_key=data.account_key, key=data.key,
return self.set_data(account_key=data.account_key, key=data.key,
cert=cert, chain=chain)

def save(self, data):
return super(FullChainFile, self).save(self.Data(
return super(FullChainFile, self).save(self.set_data(
account_key=data.account_key, key=data.key,
cert=None, chain=([data.cert] + data.chain)))
chain=([data.cert] + data.chain)))


class FullChainFileTest(FileIOPluginTestMixin, UnitTestCase):
Expand All @@ -736,11 +743,10 @@ class KeyFile(FileIOPlugin, OpenSSLIOPlugin):
"""Private key file plugin."""

def persisted(self):
return self.Data(account_key=False, key=True, cert=False, chain=False)
return self.set_data_bool(key=True)

def load_from_content(self, output):
return self.Data(account_key=None, key=self.load_key(output),
cert=None, chain=None)
return self.set_data(key=self.load_key(output))

def save(self, data):
return self.save_to_file(self.dump_key(data.key))
Expand All @@ -758,11 +764,10 @@ class CertFile(FileIOPlugin, OpenSSLIOPlugin):
"""Certificate file plugin."""

def persisted(self):
return self.Data(account_key=False, key=False, cert=True, chain=False)
return self.set_data_bool(cert=True)

def load_from_content(self, output):
return self.Data(account_key=None, key=None,
cert=self.load_cert(output), chain=None)
return self.set_data(cert=self.load_cert(output))

def save(self, data):
return self.save_to_file(self.dump_cert(data.cert))
Expand All @@ -779,12 +784,11 @@ class FullFile(FileIOPlugin, OpenSSLIOPlugin):
"""Private key, certificate and chain plugin."""

def persisted(self):
return self.Data(account_key=False, key=True, cert=True, chain=True)
return self.set_data_bool(key=True, cert=True, chain=True)

def load_from_content(self, content):
pems = split_pems(content)
return self.Data(
account_key=None,
return self.set_data(
key=self.load_key(next(pems)),
cert=self.load_cert(next(pems)),
chain=[self.load_cert(cert) for cert in pems],
Expand Down Expand Up @@ -1107,8 +1111,7 @@ def integration_test(args):

def check_plugins_persist_all(ioplugins):
"""Do plugins cover all components (key/cert/chain)?"""
persisted = IOPlugin.Data(
account_key=False, key=False, cert=False, chain=False)
persisted = IOPlugin.set_data_bool()
for plugin_name in ioplugins:
persisted = IOPlugin.Data(*componentwise_or(
persisted, IOPlugin.registered[plugin_name].persisted()))
Expand Down Expand Up @@ -1314,7 +1317,7 @@ def persist_new_data(args, existing_data):
key = ComparablePKey(gen_pkey(args.cert_key_size))
csr = gen_csr(key.wrapped, [vhost.name.encode() for vhost in args.vhosts])
certr = get_certr(client, csr, authorizations)
persist_data(args, existing_data, new_data=IOPlugin.Data(
persist_data(args, existing_data, new_data=IOPlugin.set_data(
account_key=client.key, key=key,
cert=certr.body, chain=client.fetch_chain(certr)))

Expand Down