diff --git a/src/sample_pages_new.py b/src/sample_pages_new.py index 245d66c..a398d2d 100644 --- a/src/sample_pages_new.py +++ b/src/sample_pages_new.py @@ -5,23 +5,30 @@ import numpy as np import pandas as pd from lxml import etree -import argparse, progressbar, hashlib +import argparse +import tqdm +import hashlib -from pyriksdagen.utils import infer_metadata, protocol_iterators +from pyriksdagen.utils import corpus_iterator tei_ns = "{http://www.tei-c.org/ns/1.0}" xml_ns = "{http://www.w3.org/XML/1998/namespace}" +from trainerlog import get_logger +LOGGER = get_logger("sample-pages") + def get_date(root): for docDate in root.findall(f".//{tei_ns}docDate"): date_string = docDate.text + date_string = " ".join(date_string.split()).strip() break return date_string def get_page_counts(corpus_path="corpus/protocols/"): + LOGGER.info("Load records in to calculate page counts...") parser = etree.XMLParser(remove_blank_text=True) rows = [] - for protocol_path in progressbar.progressbar(list(protocol_iterators(corpus_path, start=args.start, end=args.end))): + for protocol_path in tqdm.tqdm(list(corpus_iterator("prot", corpus_root=corpus_path, start=args.start, end=args.end))): root = etree.parse(protocol_path, parser) pbs = root.findall(f".//{tei_ns}pb") year = get_date(root)[:4] @@ -31,13 +38,6 @@ def get_page_counts(corpus_path="corpus/protocols/"): df = pd.DataFrame(rows, columns=["protocol_path", "protocol_id", "year", "pages"]) return df -def get_pagenumber(link): - link = link.replace(".jp2/_view", "") - link = link.split("-")[-1] - link = link.split("page=")[-1] - if link.isnumeric(): - return int(link) - def sample_page_counts(df, start, end, n, seed=None): df = df[df["year"] >= start] df = df[df["year"] <= end].copy() @@ -133,25 +133,28 @@ def flatten(df): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('--records_folder', type=str, default="corpus/protocols") parser.add_argument('--qc_folder', type=str, default="input/quality-control") 
- parser.add_argument("-f", '--seed', type=str, default=None, help="Random state seed") + parser.add_argument("-f", '--seed', type=str, required=True, help="Random state seed") + parser.add_argument("-b", "--branch", type=str, default="main", help="Github branch where curation is happening.") + parser.add_argument('-p', '--pages_per_decade', type=int, default=30, help="How many pages per decade? 30") + parser.add_argument("-s", "--start", type=int, default=1920, help="Start year") + parser.add_argument("-e", "--end", type=int, default=2022, help="End year") + parser.add_argument("--flatten", type=bool, default=False, help="Flatten output to only contain pages instead of elements") + parser.add_argument("--output_file", type=str, default=None, help="Write output here, to a single CSV file, instead of one per decade") args = parser.parse_args() + LOGGER.train(f"Args: {args}") digest = hashlib.md5(args.seed.encode("utf-8")).digest() digest = int.from_bytes(digest, "big") % (2**32) path = args.records_folder protocol_df = get_page_counts(path) - print(protocol_df) + LOGGER.info(f"Do sampling for the following records:\n{protocol_df}") + all_samples = [] for decade in range(args.start // 10 * 10, args.end, 10): - print("Decade:", decade) + LOGGER.info(f"Decade: {decade}") sample = sample_page_counts(protocol_df, decade, decade + 9, n=args.pages_per_decade, seed=digest) - print(sample) + LOGGER.info(f"Sample:\n{sample}") prng = np.random.RandomState( (digest+decade) % (2**32)) sample = sample_pages(sample, random_state=prng) @@ -169,9 +172,16 @@ def flatten(df): if args.flatten: sample = flatten(sample) - sample.to_csv(f"{args.qc_folder}/sample_{decade}.csv", index=False) - - protocols_unique = list(sample.protocol_id.unique()) - with open(f"{args.qc_folder}/sample_{decade}.txt", "w+") as outf: - for up in protocols_unique: - outf.write(f"{args.records_folder}/{up.split('-')[1]}/{up}.xml\n") + if args.output_file is None: + sample.to_csv(f"{args.qc_folder}/sample_{decade}.csv", 
index=False) + + protocols_unique = list(sample.protocol_id.unique()) + with open(f"{args.qc_folder}/sample_{decade}.txt", "w+") as outf: + for up in protocols_unique: + outf.write(f"{args.records_folder}/{up.split('-')[1]}/{up}.xml\n") + else: + all_samples.append(sample) + + if args.output_file is not None: + sample = pd.concat(all_samples) + sample.to_csv(args.output_file)