Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions helical/models/geneformer/geneformer_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,95 @@ def __init__(
# protein-coding and miRNA gene list dictionary for selecting .h5ad columns for tokenization
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

def get_gene_ranks(self, adata_obj, gene_names="index"):
    """
    Compute per-cell gene ranks obtained from standard rank value encoding.

    For every cell that passes filtering, the counts of vocabulary genes are
    divided by the per-gene median normalisation factor and the strictly
    nonzero scaled values are ranked in descending order within the cell.

    Parameters
    ----------
    adata_obj : AnnData
        Raw counts scRNAseq data. Gene symbols will be mapped to
        Ensembl IDs if 'ensembl_id' is not already in adata.var.
    gene_names : str
        Column in adata.var containing gene names, or "index" to use
        var_names. Only consulted when 'ensembl_id' is missing from
        adata.var; if that column already exists this argument is ignored.

    Returns
    -------
    rank_matrix : np.ndarray
        int32 array with one row per cell and one column per gene of the
        AnnData returned by ``sum_ensembl_ids``.
        NOTE(review): that object may have fewer columns than the input
        when gene IDs are collapsed -- confirm against ``sum_ensembl_ids``.
        Entries are 1-indexed ranks (1 = highest median-normalized
        expression in that cell). Zero means the gene is not expressed
        or not in the model vocabulary. Rows of cells whose
        ``obs["filter_pass"]`` is not 1 are left entirely zero.
    context_length : int
        Effective context length (model_input_size, or model_input_size - 2
        if special tokens are used). Genes with rank <= context_length
        are "in context" for the model. Ranks are NOT truncated to this
        value; it is returned so callers can apply the cut-off themselves.
    """
    if "ensembl_id" not in adata_obj.var.columns:
        # Imported lazily, only on the code path that needs the mapping.
        from helical.utils.mapping import map_gene_symbols_to_ensembl_ids

        col = gene_names if gene_names != "index" else None
        adata_obj = map_gene_symbols_to_ensembl_ids(adata_obj, col)

    adata = sum_ensembl_ids(
        adata_obj,
        self.collapse_gene_ids,
        self.gene_mapping_dict,
        self.gene_token_dict,
        file_format="h5ad",
        chunk_size=self.chunk_size,
    )

    # Identify vocabulary genes (same filter as tokenize_anndata)
    coding_miRNA_loc = np.where(
        [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
    )[0]
    norm_factor_vector = np.array(
        [
            self.gene_median_dict[i]
            for i in adata.var["ensembl_id"].iloc[coding_miRNA_loc]
        ]
    )
    # Row vector so the division broadcasts over cells; loop-invariant.
    norm_factor_row = norm_factor_vector.reshape(1, -1)

    # Effective context length (account for CLS/EOS tokens)
    context_length = (
        self.model_input_size - 2 if self.special_token else self.model_input_size
    )

    if "filter_pass" in adata.obs.columns:
        filter_pass_loc = np.where(
            [i == 1 for i in adata.obs["filter_pass"]]
        )[0]
    else:
        filter_pass_loc = np.arange(adata.shape[0])

    n_cells = len(filter_pass_loc)
    n_genes = adata.shape[1]
    # Sized by ALL cells so row indices line up with adata.obs; cells that
    # fail the filter simply keep their zero rows.
    rank_matrix = np.zeros((adata.shape[0], n_genes), dtype=np.int32)

    for start in range(0, n_cells, self.chunk_size):
        idx = filter_pass_loc[start : start + self.chunk_size]
        X_view = adata[idx, :].X[:, coding_miRNA_loc]

        # Median-scale: e_{c,g} = r_{c,g} / m_g
        X_scaled = sp.csr_matrix(X_view / norm_factor_row)
        # Explicitly stored zeros (possible when X is sparse) must not be
        # ranked: a rank of 0 is the documented marker for "not expressed".
        X_scaled.eliminate_zeros()

        indptr = X_scaled.indptr
        indices = X_scaled.indices
        data = X_scaled.data

        # Compute ranks per cell and write into dense matrix
        for j, cell_row in enumerate(idx):
            lo, hi = indptr[j], indptr[j + 1]
            if lo == hi:
                # No expressed vocabulary gene in this cell.
                continue

            # Descending order of scaled expression; position k in `order`
            # holds the entry that receives rank k + 1.
            order = np.argsort(-data[lo:hi])
            ranks_arr = np.empty_like(order, dtype=np.int32)
            ranks_arr[order] = np.arange(1, len(order) + 1, dtype=np.int32)

            # Map chunk-local column positions back to adata.var positions.
            orig_cols = coding_miRNA_loc[indices[lo:hi]]
            rank_matrix[cell_row, orig_cols] = ranks_arr

    return rank_matrix, context_length

def tokenize_data(
self,
data_directory: Path | str,
Expand Down
Loading