Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions src/sample_pages_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@
import numpy as np
import pandas as pd
from lxml import etree
import argparse, progressbar, hashlib
import argparse
import tqdm
import hashlib

from pyriksdagen.utils import infer_metadata, protocol_iterators
from pyriksdagen.utils import corpus_iterator

tei_ns = "{http://www.tei-c.org/ns/1.0}"
xml_ns = "{http://www.w3.org/XML/1998/namespace}"

from trainerlog import get_logger
LOGGER = get_logger("sample-pages")

def get_date(root):
    """Return the normalized text of the first ``docDate`` element in *root*.

    Runs of internal whitespace in the date string are collapsed to
    single spaces and the ends are stripped.

    Args:
        root: parsed TEI document (lxml/ElementTree root or tree)
            searched with the ``tei_ns`` namespace prefix.

    Returns:
        str: the cleaned date string of the first ``docDate`` found.

    Raises:
        ValueError: if the document contains no ``docDate`` element.
            (The original code left ``date_string`` unassigned in that
            case and crashed with an opaque ``NameError`` instead.)
    """
    for docDate in root.findall(f".//{tei_ns}docDate"):
        # Only the first docDate is used; normalize its whitespace.
        return " ".join(docDate.text.split()).strip()
    raise ValueError("Document contains no docDate element")

def get_page_counts(corpus_path="corpus/protocols/"):
LOGGER.info("Load records in to calculate page counts...")
parser = etree.XMLParser(remove_blank_text=True)
rows = []
for protocol_path in progressbar.progressbar(list(protocol_iterators(corpus_path, start=args.start, end=args.end))):
for protocol_path in tqdm.tqdm(list(corpus_iterator("prot", corpus_root=corpus_path, start=args.start, end=args.end))):
root = etree.parse(protocol_path, parser)
pbs = root.findall(f".//{tei_ns}pb")
year = get_date(root)[:4]
Expand All @@ -31,13 +38,6 @@ def get_page_counts(corpus_path="corpus/protocols/"):
df = pd.DataFrame(rows, columns=["protocol_path", "protocol_id", "year", "pages"])
return df

def get_pagenumber(link):
    """Parse the trailing page number out of a facsimile page link.

    Strips a ``.jp2/_view`` suffix, keeps the text after the last
    hyphen, drops a leading ``page=`` marker if present, and converts
    what remains to an int.  Returns None when that remainder is not
    purely numeric.
    """
    tail = link.replace(".jp2/_view", "").split("-")[-1]
    number = tail.split("page=")[-1]
    if number.isnumeric():
        return int(number)
    return None

def sample_page_counts(df, start, end, n, seed=None):
df = df[df["year"] >= start]
df = df[df["year"] <= end].copy()
Expand Down Expand Up @@ -133,25 +133,28 @@ def flatten(df):
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--records_folder', type=str, default="corpus/protocols")
parser.add_argument('--qc_folder', type=str, default="input/quality-control")
parser.add_argument("-f", '--seed', type=str, default=None, help="Random state seed")
parser.add_argument("-f", '--seed', type=str, required=True, help="Random state seed")
parser.add_argument("-b", "--branch", type=str, default="main", help="Github branch where curation is happening.")
parser.add_argument('-p', '--pages_per_decade', type=int, default=30, help="How many pages per decade? 30")
parser.add_argument("-s", "--start", type=int, default=1920, help="Start year")
parser.add_argument("-e", "--end", type=int, default=2022, help="End year")
parser.add_argument("--flatten", type=bool, default=False, help="Flatten output to only contain pages instead of elements")
parser.add_argument("--output_file", type=str, default=None, help="Write output here, to a single CSV file, intead of one per decade")
args = parser.parse_args()
LOGGER.train(f"Args: {args}")

digest = hashlib.md5(args.seed.encode("utf-8")).digest()
digest = int.from_bytes(digest, "big") % (2**32)

path = args.records_folder
protocol_df = get_page_counts(path)
print(protocol_df)
LOGGER.info(f"Do sampling for the following records:\n{protocol_df}")

all_samples = []
for decade in range(args.start // 10 * 10, args.end, 10):
print("Decade:", decade)
LOGGER.info(f"Decade: {decade}")
sample = sample_page_counts(protocol_df, decade, decade + 9, n=args.pages_per_decade, seed=digest)
print(sample)
LOGGER.info(f"Sample:\n{sample}")

prng = np.random.RandomState( (digest+decade) % (2**32))
sample = sample_pages(sample, random_state=prng)
Expand All @@ -169,9 +172,16 @@ def flatten(df):
if args.flatten:
sample = flatten(sample)

sample.to_csv(f"{args.qc_folder}/sample_{decade}.csv", index=False)

protocols_unique = list(sample.protocol_id.unique())
with open(f"{args.qc_folder}/sample_{decade}.txt", "w+") as outf:
for up in protocols_unique:
outf.write(f"{args.records_folder}/{up.split('-')[1]}/{up}.xml\n")
if args.output_file is None:
sample.to_csv(f"{args.qc_folder}/sample_{decade}.csv", index=False)

protocols_unique = list(sample.protocol_id.unique())
with open(f"{args.qc_folder}/sample_{decade}.txt", "w+") as outf:
for up in protocols_unique:
outf.write(f"{args.records_folder}/{up.split('-')[1]}/{up}.xml\n")
else:
all_samples.append(sample)

if args.output_file is not None:
sample = pd.concat(all_samples)
sample.to_csv(args.output_file)
Loading