@@ -99,7 +99,7 @@ def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
   def _decode(self, ids):
     raise NotImplementedError

-  def decode(self, ids: Iterable[int]):
+  def decode(self, ids: Iterable[int]) -> str:
     """Detokenizes int32 iterable to a string, up through first EOS."""
     # A `tf.Tensor` is `Iterable` so it's valid to pass into this function.
     # However, iterating over a 1D EagerTensor will create a scalar EagerTensor
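The comment preserved in this hunk is the motivation for the rewrite below: iterating a 1-D EagerTensor yields one scalar EagerTensor per element, so `decode` converts to a plain Python list once, up front. A minimal sketch of the two paths (assuming TensorFlow is installed; the tensor contents are made-up example ids):

```python
import tensorflow as tf

ids = tf.constant([5, 8, 1, 0], dtype=tf.int32)

# Iterating the tensor directly produces a scalar EagerTensor per element,
# paying the Python/C++ boundary cost on every id.
per_element = [int(i) for i in ids]

# Converting once up front keeps the per-id work in plain Python.
converted = ids.numpy().tolist()

assert per_element == converted == [5, 8, 1, 0]
```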
@@ -109,14 +109,23 @@ def decode(self, ids: Iterable[int])
       ids: tf.Tensor = ids
       ids = ids.numpy().tolist()

-    clean_ids = list(ids)
-
-    if self.unk_id is not None:
-      vocab_size = self._base_vocab_size
-      clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
-
-    if self.eos_id is not None and self.eos_id in clean_ids:
-      clean_ids = clean_ids[:clean_ids.index(self.eos_id) + 1]
+    unk_id = self.unk_id
+    eos_id = self.eos_id
+    vocab_size = self._base_vocab_size if unk_id is not None else None
+
+    clean_ids = []
+    if vocab_size is not None:
+      for i in ids:
+        if i >= vocab_size:
+          i = unk_id
+        clean_ids.append(i)
+        if i == eos_id:
+          break
+    else:
+      for i in ids:
+        clean_ids.append(i)
+        if i == eos_id:
+          break

     return self._decode(clean_ids)

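The old body materialized the full id list and scanned it up to three times: the `list()` copy, the UNK comprehension, and the `in`/`index` pair for EOS. The rewrite does a single pass that stops at the first EOS, and hoists the `unk_id is None` test out of the loop by duplicating the loop body. A standalone sketch of the new control flow, using made-up ids (UNK=2, EOS=1, vocab size 32000):

```python
def clean_for_decode(ids, unk_id=2, eos_id=1, vocab_size=32000):
  """One pass: map out-of-range ids to UNK, stop after the first EOS."""
  clean_ids = []
  for i in ids:
    if unk_id is not None and i >= vocab_size:
      i = unk_id
    clean_ids.append(i)
    if i == eos_id:
      break
  return clean_ids


# 40000 is out of range and becomes UNK; EOS ends the scan, so the
# trailing 9 is never touched.
assert clean_for_decode([5, 40000, 1, 9]) == [5, 2, 1]
```

The early exit also means ids past the first EOS are never read at all, which matters when the input is mostly padding.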
@@ -415,6 +431,11 @@ def __init__(
     self._normalizer_spec_overrides = normalizer_spec_overrides
     self._reverse_extra_ids = reverse_extra_ids
     self._model: Optional[_ModelContext] = None
+    self._cached_unk_id: Optional[int] = None
+    self._cached_bos_id: Optional[int] = None
+    self._cached_eos_id: Optional[int] = None
+    self._cached_pad_id: Optional[int] = None
+    self._cached_piece_size: Optional[int] = None
     self._use_fast_tokenizer = use_fast_tokenizer

     super().__init__(extra_ids=extra_ids)
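Pre-declaring the cache slots as `None` in `__init__` lets the properties below distinguish "model not loaded yet" from "loaded". A rough alternative would be `functools.cached_property`, sketched here on a toy class; the explicit sentinels are presumably preferred because all five caches are filled together at the one point where the model context is built:

```python
import functools


class _Toy:
  """Hypothetical stand-in showing the cached_property alternative."""

  @functools.cached_property
  def eos_id(self) -> int:
    print("computed once")
    return 1


t = _Toy()
assert t.eos_id == 1  # prints "computed once"
assert t.eos_id == 1  # second read is served from the instance dict
```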
@@ -458,19 +479,24 @@ def _model_context(
         normalizer_spec_overrides_serialized,
         self._reverse_extra_ids,
     )
+    self._cached_unk_id = self._model.tokenizer.unk_id()
+    self._cached_bos_id = self._model.tokenizer.bos_id()
+    self._cached_eos_id = self._model.tokenizer.eos_id()
+    self._cached_pad_id = self._model.tokenizer.pad_id()
+    self._cached_piece_size = self._model.tokenizer.GetPieceSize()
     return self._model

   @property
   def bos_id(self) -> Optional[int]:
-    return self.tokenizer.bos_id()
+    return self._cached_bos_id if self._model else self.tokenizer.bos_id()

   @property
   def eos_id(self) -> Optional[int]:
-    return self.tokenizer.eos_id()
+    return self._cached_eos_id if self._model else self.tokenizer.eos_id()

   @property
   def unk_id(self) -> Optional[int]:
-    return self.tokenizer.unk_id()
+    return self._cached_unk_id if self._model else self.tokenizer.unk_id()

   @property
   def sp_model(self) -> Optional[bytes]:
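Each property now returns the cached value once `self._model` exists and only falls back to a live `tokenizer` call before the first load (a call that itself builds the model context and fills the caches). The payoff is that hot-path accessors like `eos_id` become a plain attribute read instead of a call into SentencePiece's pybind layer on every access. A toy measurement of the pattern, with pure-Python stand-ins (the real pybind round trip is costlier still):

```python
import timeit


class _FakeTokenizer:
  def eos_id(self) -> int:
    return 1


class _Uncached:
  def __init__(self):
    self._tok = _FakeTokenizer()

  @property
  def eos_id(self):
    return self._tok.eos_id()  # method call on every access


class _Cached:
  def __init__(self):
    self._tok = _FakeTokenizer()
    self._cached_eos_id = self._tok.eos_id()  # computed once

  @property
  def eos_id(self):
    return self._cached_eos_id  # attribute read only


u, c = _Uncached(), _Cached()
print(timeit.timeit(lambda: u.eos_id))
print(timeit.timeit(lambda: c.eos_id))  # consistently faster
```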
@@ -495,7 +521,11 @@ def tf_tokenizer(self):

   @property
   def vocab_size(self):
-    return self._base_vocab_size
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )

   @property
   def _base_vocab_size(self):
@@ -504,7 +534,11 @@ def _base_vocab_size(self):
     Returns:
       an integer, the vocabulary size
     """
-    return self.tokenizer.GetPieceSize()
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )

   def _encode(self, s: str) -> Sequence[int]:
     """Encode a python string as a list of integers.
@@ -517,7 +551,7 @@ def _encode(self, s: str) -> Sequence[int]:
517551 """
518552 return self .tokenizer .EncodeAsIds (s )
519553
520- def _decode (self , ids ) :
554+ def _decode (self , ids : Sequence [ int ]) -> str :
521555 """Decode a list of integers to a python string.
522556
523557 Args:
@@ -526,11 +560,7 @@ def _decode(self, ids):
     Returns:
       a string
     """
-    # convert all the extra ids (sentinels) to UNK=2
-    unk_id = self.tokenizer.unk_id()
-    piece_size = self.tokenizer.GetPieceSize()
-    ids = [unk_id if i >= piece_size else int(i) for i in ids]
-    return self.tokenizer.DecodeIds(ids)
+    return self.tokenizer.DecodeIds(list(ids))

   def _encode_tf(self, s):
     """Encode a tf.Scalar string to a tf.Tensor.