Skip to content

Commit aefdc99

Browse files
tomvdw authored and SeqIO committed
Optimize decoding in SentencePieceVocabulary
PiperOrigin-RevId: 864272339
1 parent db78942 commit aefdc99

1 file changed

Lines changed: 50 additions & 20 deletions

File tree

seqio/vocabularies.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
9999
def _decode(self, ids):
100100
raise NotImplementedError
101101

102-
def decode(self, ids: Iterable[int]):
102+
def decode(self, ids: Iterable[int]) -> str:
103103
"""Detokenizes int32 iterable to a string, up through first EOS."""
104104
# A `tf.Tensor` is `Iterable` so it's valid to pass into this function.
105105
# However, iterating over a 1D EagerTensor will create a scalar EagerTensor
@@ -109,14 +109,30 @@ def decode(self, ids: Iterable[int]):
109109
ids: tf.Tensor = ids
110110
ids = ids.numpy().tolist()
111111

112-
clean_ids = list(ids)
113-
114-
if self.unk_id is not None:
115-
vocab_size = self._base_vocab_size
116-
clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
117-
118-
if self.eos_id is not None and self.eos_id in clean_ids:
119-
clean_ids = clean_ids[: clean_ids.index(self.eos_id) + 1]
112+
unk_id = self.unk_id
113+
eos_id = self.eos_id
114+
vocab_size = self._base_vocab_size if unk_id is not None else None
115+
116+
clean_ids = []
117+
if vocab_size is not None:
118+
for i in ids:
119+
if i >= vocab_size:
120+
i = unk_id
121+
clean_ids.append(i)
122+
if i == eos_id:
123+
break
124+
else:
125+
for i in ids:
126+
clean_ids.append(i)
127+
if i == eos_id:
128+
break
129+
# clean_ids = []
130+
# for i in ids:
131+
# if vocab_size is not None and i >= vocab_size:
132+
# i = unk_id
133+
# clean_ids.append(i)
134+
# if i == eos_id:
135+
# break
120136

121137
return self._decode(clean_ids)
122138

@@ -415,6 +431,11 @@ def __init__(
415431
self._normalizer_spec_overrides = normalizer_spec_overrides
416432
self._reverse_extra_ids = reverse_extra_ids
417433
self._model: Optional[_ModelContext] = None
434+
self._cached_unk_id: Optional[int] = None
435+
self._cached_bos_id: Optional[int] = None
436+
self._cached_eos_id: Optional[int] = None
437+
self._cached_pad_id: Optional[int] = None
438+
self._cached_piece_size: Optional[int] = None
418439
self._use_fast_tokenizer = use_fast_tokenizer
419440

420441
super().__init__(extra_ids=extra_ids)
@@ -458,19 +479,24 @@ def _model_context(
458479
normalizer_spec_overrides_serialized,
459480
self._reverse_extra_ids,
460481
)
482+
self._cached_unk_id = self._model.tokenizer.unk_id()
483+
self._cached_bos_id = self._model.tokenizer.bos_id()
484+
self._cached_eos_id = self._model.tokenizer.eos_id()
485+
self._cached_pad_id = self._model.tokenizer.pad_id()
486+
self._cached_piece_size = self._model.tokenizer.GetPieceSize()
461487
return self._model
462488

463489
@property
464490
def bos_id(self) -> Optional[int]:
465-
return self.tokenizer.bos_id()
491+
return self._cached_bos_id if self._model else self.tokenizer.bos_id()
466492

467493
@property
468494
def eos_id(self) -> Optional[int]:
469-
return self.tokenizer.eos_id()
495+
return self._cached_eos_id if self._model else self.tokenizer.eos_id()
470496

471497
@property
472498
def unk_id(self) -> Optional[int]:
473-
return self.tokenizer.unk_id()
499+
return self._cached_unk_id if self._model else self.tokenizer.unk_id()
474500

475501
@property
476502
def sp_model(self) -> Optional[bytes]:
@@ -495,7 +521,11 @@ def tf_tokenizer(self):
495521

496522
@property
497523
def vocab_size(self):
498-
return self._base_vocab_size
524+
return (
525+
self._cached_piece_size
526+
if self._model
527+
else self.tokenizer.GetPieceSize()
528+
)
499529

500530
@property
501531
def _base_vocab_size(self):
@@ -504,7 +534,11 @@ def _base_vocab_size(self):
504534
Returns:
505535
an integer, the vocabulary size
506536
"""
507-
return self.tokenizer.GetPieceSize()
537+
return (
538+
self._cached_piece_size
539+
if self._model
540+
else self.tokenizer.GetPieceSize()
541+
)
508542

509543
def _encode(self, s: str) -> Sequence[int]:
510544
"""Encode a python string as a list of integers.
@@ -517,7 +551,7 @@ def _encode(self, s: str) -> Sequence[int]:
517551
"""
518552
return self.tokenizer.EncodeAsIds(s)
519553

520-
def _decode(self, ids):
554+
def _decode(self, ids: Sequence[int]) -> str:
521555
"""Decode a list of integers to a python string.
522556
523557
Args:
@@ -526,11 +560,7 @@ def _decode(self, ids):
526560
Returns:
527561
a string
528562
"""
529-
# convert all the extra ids (sentinels) to UNK=2
530-
unk_id = self.tokenizer.unk_id()
531-
piece_size = self.tokenizer.GetPieceSize()
532-
ids = [unk_id if i >= piece_size else int(i) for i in ids]
533-
return self.tokenizer.DecodeIds(ids)
563+
return self.tokenizer.DecodeIds(list(ids))
534564

535565
def _encode_tf(self, s):
536566
"""Encode a tf.Scalar string to a tf.Tensor.

0 commit comments

Comments (0)