diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 6833ae76..321e2737 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -51,6 +51,6 @@ jobs:
       shell: cmd
       run: |
         pip install pytest
-        pytest test\util.py test\metadata.py test\integration\dummy.py test\integration\vaswani.py test\formats\
+        pytest test\util.py test\metadata.py test\integration\dummy.py test\integration\vaswani.py test\formats\ test\test_local.py
       env:
         PATH: 'C:/Program Files/zlib/bin/'
diff --git a/ir_datasets/__init__.py b/ir_datasets/__init__.py
index 48c19007..dd94c32a 100644
--- a/ir_datasets/__init__.py
+++ b/ir_datasets/__init__.py
@@ -6,20 +6,26 @@ class EntityType(Enum):
     scoreddocs = "scoreddocs"
     docpairs = "docpairs"
     qlogs = "qlogs"
+
+from . import util
+registry = util.Registry()
+
+def load(name):
+    return registry[name]
+
 from . import lazy_libs
 from . import log
-from . import util
 from . import formats
-registry = util.Registry()
 from . import datasets
 from . import indices
 from . import wrappers
 from . import commands
 
-Dataset = datasets.base.Dataset
+create_local_dataset = datasets.create_local_dataset
+delete_local_dataset = datasets.delete_local_dataset
+iter_local_datasets = datasets.iter_local_datasets
 
-def load(name):
-    return registry[name]
+Dataset = datasets.base.Dataset
 
 
 def parent_id(dataset_id: str, entity_type: EntityType) -> str:
diff --git a/ir_datasets/datasets/__init__.py b/ir_datasets/datasets/__init__.py
index baa84b31..0b3e51ea 100644
--- a/ir_datasets/datasets/__init__.py
+++ b/ir_datasets/datasets/__init__.py
@@ -44,5 +44,7 @@
 from . import wapo
 from . import wikir
 from . import trec_fair_2021
+
+from .local import iter_local_datasets, create_local_dataset, delete_local_dataset
 from . import trec_cast  # must be after wapo,car,msmarco_passage
 from . import hc4
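Review note: a rough sketch of how the helpers exported above are intended to be used (the dataset id and records here are made up for illustration; see ir_datasets/datasets/local.py below for the implementation):

    import ir_datasets

    # hypothetical records; any iterable of dicts (or NamedTuples) should work
    docs = [
        {'doc_id': 'd1', 'text': 'hello world'},
        {'doc_id': 'd2', 'text': 'another document'},
    ]
    ds = ir_datasets.create_local_dataset('demo/local', docs=docs)
    assert ir_datasets.load('demo/local').docs.lookup('d1').text == 'hello world'
    ir_datasets.delete_local_dataset('demo/local')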
diff --git a/ir_datasets/datasets/base.py b/ir_datasets/datasets/base.py
index 0e3149fc..2303afcb 100644
--- a/ir_datasets/datasets/base.py
+++ b/ir_datasets/datasets/base.py
@@ -97,6 +97,19 @@ def has_docpairs(self):
     def has_qlogs(self):
         return self.has(ir_datasets.EntityType.qlogs)
 
+    def handler(self, etype: ir_datasets.EntityType):
+        etype = ir_datasets.EntityType(etype)  # validate & allow strings
+        return getattr(self, f'{etype.value}_handler')()
+
+    def clear_cache(self):
+        for bapi in self._beta_apis.values():
+            if hasattr(bapi, 'clear_cache'):
+                bapi.clear_cache()
+        self._beta_apis.clear()
+        for c in self._constituents:
+            if hasattr(c, 'clear_cache'):
+                c.clear_cache()
+
 
 class _BetaPythonApiDocs:
     def __init__(self, handler):
@@ -136,6 +149,11 @@ def lookup_iter(self, doc_ids):
     def metadata(self):
         return self._handler.docs_metadata()
 
+    def clear_cache(self):
+        if self._docstore is not None and hasattr(self._docstore, 'clear_cache'):
+            self._docstore.clear_cache()
+        self._docstore = None
+
 
 class _BetaPythonApiQueries:
     def __init__(self, handler):
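Review note: a minimal sketch of the two new Dataset methods (assuming any registered dataset id, 'vaswani' here):

    import ir_datasets

    ds = ir_datasets.load('vaswani')
    handler = ds.handler('docs')  # equivalent to ds.handler(ir_datasets.EntityType.docs)
    ds.clear_cache()              # drops cached beta-API objects and constituents' caches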
diff --git a/ir_datasets/datasets/local.py b/ir_datasets/datasets/local.py
new file mode 100644
index 00000000..61d3ca51
--- /dev/null
+++ b/ir_datasets/datasets/local.py
@@ -0,0 +1,296 @@
+from .base import Dataset
+import itertools
+import pickle
+import json
+import shutil
+import uuid
+from typing import NamedTuple
+import ir_datasets
+from ir_datasets import EntityType
+from ir_datasets.indices import Lz4PickleLookup, PickleLz4FullStore
+from ir_datasets.formats import BaseDocs, BaseQueries, BaseQrels, BaseScoredDocs, BaseDocPairs, BaseQlogs
+
+logger = ir_datasets.log.easy()
+
+NAME = 'local'
+
+BASE_PATH = ir_datasets.util.home_path()/NAME
+
+
+class BaseLocal:
+    def __init__(self, path, etype):
+        self._path = path
+        self._etype = etype
+        self._cls_ = None
+
+    def _iter(self):
+        cls = self._cls()
+        lz4_frame = ir_datasets.lazy_libs.lz4_frame().frame
+        with lz4_frame.LZ4FrameFile(self._path/'data', 'rb') as fin:
+            while fin.peek(1):
+                yield cls(*pickle.load(fin))
+
+    def _count(self):
+        with (self._path/'count.pkl').open('rb') as fin:
+            return pickle.load(fin)
+
+    def _cls(self):
+        if self._cls_ is None:
+            with (self._path/'type.pkl').open('rb') as fin:
+                name, attrs = pickle.load(fin)
+            self._cls_ = NamedTuple(name, list(attrs.items()))
+        return self._cls_
+
+    @classmethod
+    def create(cls, path, records):
+        path.mkdir(exist_ok=True, parents=True)
+        records = iter(records)
+        try:
+            first = next(records)
+        except StopIteration:
+            raise ValueError('records must not be empty')
+        if isinstance(first, dict):
+            # infer a NamedTuple type from the fields of the first record
+            EntityCls = NamedTuple('LocalEntity', [(k, type(v)) for k, v in first.items()])
+            records = itertools.chain([EntityCls(**first)], (EntityCls(**r) for r in records))
+        else:
+            EntityCls = type(first)
+            records = itertools.chain([first], records)
+        with ir_datasets.util.finialized_file(path/'type.pkl', 'wb') as fout:
+            pickle.dump((EntityCls.__name__, EntityCls.__annotations__), fout)
+        count = cls._build_datafile(path/'data', records, EntityCls)
+        with ir_datasets.util.finialized_file(path/'count.pkl', 'wb') as fout:
+            pickle.dump(count, fout)
+
+    @classmethod
+    def _build_datafile(cls, path, records, ecls):
+        count = 0
+        lz4_frame = ir_datasets.lazy_libs.lz4_frame().frame
+        with ir_datasets.util.finialized_file(path, 'wb') as raw_fout, \
+             lz4_frame.LZ4FrameFile(raw_fout, 'wb') as fout:
+            for record in records:
+                pickle.dump(tuple(record), fout)
+                count += 1
+        return count
+
+
+class LocalDocs(BaseLocal, BaseDocs):
+    def __init__(self, path):
+        super().__init__(path, EntityType.docs)
+
+    def docs_iter(self):
+        return iter(self.docs_store())
+
+    def docs_store(self, field='doc_id'):
+        return PickleLz4FullStore(
+            path=f'{self._path}/data',
+            init_iter_fn=None,
+            data_cls=self._cls(),
+            lookup_field='doc_id',
+            index_fields=['doc_id'],
+            count_hint=self._count,
+        )
+
+    def docs_count(self):
+        return len(self.docs_store())
+
+    def docs_cls(self):
+        return self._cls()
+
+    @classmethod
+    def _build_datafile(cls, path, records, ecls):
+        # docs use a lookup structure (rather than the flat file built by
+        # BaseLocal) so docs_store() can fetch records by doc_id; count the
+        # records so count.pkl holds a real count
+        count = 0
+        lookup = Lz4PickleLookup(path, ecls, 'doc_id', ['doc_id'])
+        with lookup.transaction() as trans:
+            for doc in records:
+                trans.add(doc)
+                count += 1
+        return count
+
+
+class LocalQueries(BaseLocal, BaseQueries):
+    def __init__(self, path):
+        super().__init__(path, EntityType.queries)
+
+    def queries_iter(self):
+        return self._iter()
+
+    def queries_count(self):
+        return self._count()
+
+    def queries_cls(self):
+        return self._cls()
+
+
+class LocalQrels(BaseLocal, BaseQrels):
+    def __init__(self, path):
+        super().__init__(path, EntityType.qrels)
+
+    def qrels_iter(self):
+        return self._iter()
+
+    def qrels_count(self):
+        return self._count()
+
+    def qrels_cls(self):
+        return self._cls()
+
+    def qrels_defs(self):
+        return {}  # no relevance level definitions for ad hoc local data
+
+
+class LocalScoredDocs(BaseLocal, BaseScoredDocs):
+    def __init__(self, path):
+        super().__init__(path, EntityType.scoreddocs)
+
+    def scoreddocs_iter(self):
+        return self._iter()
+
+    def scoreddocs_count(self):
+        return self._count()
+
+    def scoreddocs_cls(self):
+        return self._cls()
+
+
+class LocalDocpairs(BaseLocal, BaseDocPairs):
+    def __init__(self, path):
+        super().__init__(path, EntityType.docpairs)
+
+    def docpairs_iter(self):
+        return self._iter()
+
+    def docpairs_count(self):
+        return self._count()
+
+    def docpairs_cls(self):
+        return self._cls()
+
+
+class LocalQlogs(BaseLocal, BaseQlogs):
+    def __init__(self, path):
+        super().__init__(path, EntityType.qlogs)
+
+    def qlogs_iter(self):
+        return self._iter()
+
+    def qlogs_count(self):
+        return self._count()
+
+    def qlogs_cls(self):
+        return self._cls()
+
+
+PROVIDERS = {
+    EntityType.docs: LocalDocs,
+    EntityType.queries: LocalQueries,
+    EntityType.qrels: LocalQrels,
+    EntityType.scoreddocs: LocalScoredDocs,
+    EntityType.docpairs: LocalDocpairs,
+    EntityType.qlogs: LocalQlogs,
+}
+
+
+def _write_registry(registry_path, registry):
+    # one record per line for easy diffing; index-based comma logic so that
+    # duplicate records cannot break the JSON
+    with ir_datasets.util.finialized_file(registry_path, 'wt') as fout:
+        fout.write('[\n')
+        for i, item in enumerate(registry):
+            fout.write('  ' + json.dumps(item))
+            if i != len(registry) - 1:
+                fout.write(',')
+            fout.write('\n')
+        fout.write(']\n')
+
+
+def create_local_dataset(dataset_id, **sources):
+    if dataset_id in ir_datasets.registry:
+        raise KeyError(f'{dataset_id} already in registry; choose another name')
+    path = str(uuid.uuid4())
+    ds_path = BASE_PATH/path
+    components = []
+    dataset_record = {
+        'id': dataset_id,
+        'path': path,
+        'provides': {},
+    }
+    with logger.duration(f'provisioning {dataset_id} to {ds_path}'):
+        ds_path.mkdir(exist_ok=True, parents=True)
+        for etype in EntityType:
+            if etype.value not in sources:
+                continue
+            source = sources[etype.value]
+            if isinstance(source, str):
+                # a string links to an existing dataset's handler for this entity type
+                components.append(ir_datasets.load(source).handler(etype))
+                dataset_record['provides'][etype.value] = source
+            else:
+                # otherwise, materialize the provided records on disk
+                e_path = ds_path/etype.value
+                provider_cls = PROVIDERS[etype]
+                with logger.duration(f'creating {str(e_path)}'):
+                    provider_cls.create(e_path, source)
+                components.append(provider_cls(e_path))
+                dataset_record['provides'][etype.value] = None
+            del sources[etype.value]
+    if sources:
+        raise RuntimeError(f'Unexpected argument(s): {list(sources)}')
+    dataset = Dataset(*components)
+    ir_datasets.registry.register(dataset_id, dataset)
+    registry_path = BASE_PATH/'registry.json'
+    if registry_path.exists():
+        with registry_path.open('rt') as fin:
+            registry = json.load(fin)
+    else:
+        registry = []
+    registry.append(dataset_record)
+    _write_registry(registry_path, registry)
+    return dataset
+
+
+def delete_local_dataset(dataset_id, remove_files=True):
+    registry_paths = list(BASE_PATH.glob('registry*.json'))
+    for registry_path in sorted(registry_paths):
+        with registry_path.open('rt') as fin:
+            registry = json.load(fin)
+        changed = False
+        for dataset in list(registry):
+            if dataset['id'] == dataset_id:
+                registry.remove(dataset)
+                changed = True
+                try:
+                    ds = ir_datasets.registry[dataset_id]
+                    ds.clear_cache()
+                    del ir_datasets.registry[dataset_id]
+                    del ds  # clean up this dataset (e.g., close open files)
+                except KeyError:
+                    pass
+                if remove_files:
+                    with logger.duration(f'Removing {dataset["id"]} at {str(BASE_PATH/dataset["path"])}'):
+                        shutil.rmtree(BASE_PATH/dataset['path'])
+        if changed:
+            _write_registry(registry_path, registry)
+
+
+def iter_local_datasets():
+    registry_paths = list(BASE_PATH.glob('registry*.json'))
+    for registry_path in sorted(registry_paths):
+        with registry_path.open('rt') as fin:
+            registry = json.load(fin)
+        for dataset in registry:
+            dataset['registry_path'] = str(registry_path)
+            yield dataset
+
+
+def _init():
+    for dataset in iter_local_datasets():
+        if dataset['id'] in ir_datasets.registry:
+            new_id = f'local/{dataset["path"]}'
+            logger.warn(f'Local dataset {repr(dataset["id"])} from {dataset["registry_path"]} already in registry. '
+                        f'Renaming it to {new_id}.')
+            dataset['id'] = new_id
+        ds_path = BASE_PATH/dataset['path']
+        components = []
+        for etype, val in dataset['provides'].items():
+            if val is None:
+                provider_cls = PROVIDERS[EntityType(etype)]
+                components.append(provider_cls(ds_path/etype))
+            else:
+                components.append(ir_datasets.load(val).handler(etype))
+        ir_datasets.registry.register(dataset['id'], Dataset(*components))
+
+
+_init()
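Review note: for reference, the shape of the records iter_local_datasets() yields, based on the registry format written above (values are illustrative; 'provides' maps each entity type to the source dataset id, or None when the records were materialized on disk):

    import ir_datasets

    for record in ir_datasets.iter_local_datasets():
        print(record['id'], record['provides'], record['registry_path'])
    # e.g.: demo/local {'docs': None} /home/user/.ir_datasets/local/registry.json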
diff --git a/ir_datasets/indices/lz4_pickle.py b/ir_datasets/indices/lz4_pickle.py
index 3f7e3c89..4f12101c 100644
--- a/ir_datasets/indices/lz4_pickle.py
+++ b/ir_datasets/indices/lz4_pickle.py
@@ -51,6 +51,7 @@ def __init__(self, lookup, slice):
     def __next__(self):
         if self.slice.start >= self.slice.stop:
+            self.clear()
             raise StopIteration
         if self.bin is None:
             self.bin = open(self.lookup._bin_path, 'rb')
@@ -70,6 +71,9 @@ def __iter__(self):
         return self
 
     def __del__(self):
+        self.clear()
+
+    def clear(self):
         if self.bin is not None:
             self.bin.close()
             self.bin = None
@@ -150,6 +154,9 @@ def clear(self):
             os.remove(self._pos_path)
         NumpySortedIndex(self._idx_path).clear()
 
+    def clear_cache(self):
+        self.close()
+
     def __del__(self):
         self.close()
diff --git a/ir_datasets/util/registry.py b/ir_datasets/util/registry.py
index 72f756fe..005b4eaf 100644
--- a/ir_datasets/util/registry.py
+++ b/ir_datasets/util/registry.py
@@ -44,3 +44,15 @@ def register(self, name, obj):
 
     def register_pattern(self, pattern, initializer):
         self._patterns.append((re.compile(pattern), initializer))
+
+    def __contains__(self, key):
+        if key in self._registered:
+            return True
+        for pattern, _ in self._patterns:
+            if pattern.match(key):
+                return True
+        return False
+
+    def __delitem__(self, key):
+        del self._registered[key]
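Review note: the new Registry dunder methods in use. Note the asymmetry: __contains__ also consults registered patterns, while __delitem__ only removes concrete registrations (deleting a pattern-matched id raises KeyError):

    import ir_datasets

    if 'demo/local' in ir_datasets.registry:
        del ir_datasets.registry['demo/local']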
diff --git a/test/test_local.py b/test/test_local.py
new file mode 100644
index 00000000..1a120dd2
--- /dev/null
+++ b/test/test_local.py
@@ -0,0 +1,66 @@
+import unittest
+import ir_datasets
+
+
+class TestLocal(unittest.TestCase):
+    def test_local(self):
+        docs = [
+            {'doc_id': '1', 'text': 'hello world'},
+            {'doc_id': '2', 'text': 'some document'}
+        ]
+        queries = [
+            {'query_id': 'Q1', 'text': 'search'},
+            {'query_id': 'Q2', 'text': 'information retrieval'}
+        ]
+        qrels = [
+            {'query_id': 'Q1', 'doc_id': '1', 'relevance': 1},
+            {'query_id': 'Q1', 'doc_id': '2', 'relevance': 2},
+            {'query_id': 'Q2', 'doc_id': '2', 'relevance': 0},
+        ]
+        try:
+            ir_datasets.create_local_dataset('_testlocal', docs=docs)
+            ir_datasets.create_local_dataset('_testlocal/subset', docs='_testlocal', queries=queries, qrels=qrels)
+
+            test_docs = list(ir_datasets.load('_testlocal').docs)
+            self.assertEqual(test_docs[0].doc_id, '1')
+            self.assertEqual(test_docs[0].text, 'hello world')
+            self.assertEqual(test_docs[1].doc_id, '2')
+            self.assertEqual(test_docs[1].text, 'some document')
+
+            # the subset links back to _testlocal's docs handler
+            test_docs = list(ir_datasets.load('_testlocal/subset').docs)
+            self.assertEqual(test_docs[0].doc_id, '1')
+            self.assertEqual(test_docs[0].text, 'hello world')
+            self.assertEqual(test_docs[1].doc_id, '2')
+            self.assertEqual(test_docs[1].text, 'some document')
+
+            self.assertEqual(ir_datasets.load('_testlocal').docs.lookup('1').text, 'hello world')
+            self.assertEqual(ir_datasets.load('_testlocal').docs.lookup('2').text, 'some document')
+
+            self.assertEqual(ir_datasets.load('_testlocal/subset').docs.lookup('1').text, 'hello world')
+            self.assertEqual(ir_datasets.load('_testlocal/subset').docs.lookup('2').text, 'some document')
+
+            test_queries = list(ir_datasets.load('_testlocal/subset').queries)
+            self.assertEqual(test_queries[0].query_id, 'Q1')
+            self.assertEqual(test_queries[0].text, 'search')
+            self.assertEqual(test_queries[1].query_id, 'Q2')
+            self.assertEqual(test_queries[1].text, 'information retrieval')
+
+            test_qrels = list(ir_datasets.load('_testlocal/subset').qrels)
+            self.assertEqual(test_qrels[0].query_id, 'Q1')
+            self.assertEqual(test_qrels[0].doc_id, '1')
+            self.assertEqual(test_qrels[0].relevance, 1)
+            self.assertEqual(test_qrels[1].query_id, 'Q1')
+            self.assertEqual(test_qrels[1].doc_id, '2')
+            self.assertEqual(test_qrels[1].relevance, 2)
+            self.assertEqual(test_qrels[2].query_id, 'Q2')
+            self.assertEqual(test_qrels[2].doc_id, '2')
+            self.assertEqual(test_qrels[2].relevance, 0)
+        finally:
+            # clean up even if provisioning failed partway through; only touch
+            # datasets that were actually registered
+            for dataset_id in ('_testlocal', '_testlocal/subset'):
+                if dataset_id in ir_datasets.registry:
+                    ir_datasets.load(dataset_id).clear_cache()
+                    ir_datasets.delete_local_dataset(dataset_id)
+
+
+if __name__ == '__main__':
+    unittest.main()
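Review note: to exercise the new code path locally (mirroring the CI change at the top of this diff):

    python -m pytest test/test_local.py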