Encode and decode example with T5

In [1]:
%pip install -q "transformers[sentencepiece]~=4.33.0"
In [2]:
import IPython.display as ipd
import torch

torch.__version__
Out[2]:
'2.0.1+cu118'
In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# The encoder of the full model is available as `model.encoder`,
# so a separate T5EncoderModel does not need to be loaded.
model = T5ForConditionalGeneration.from_pretrained("t5-base")
# Padding and truncation are call-time options, so they are passed when encoding below.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
/usr/local/lib/python3.10/dist-packages/transformers/models/t5/tokenization_t5.py:220: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
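The next cell tokenizes with `padding=True`, which pads every sequence to the longest one in the batch and returns an `attention_mask` holding 1 for real tokens and 0 for padding. With the single input used here nothing gets padded; the mask only matters once a batch holds texts of different lengths. The stdlib-only sketch below mimics that right-padding behaviour using T5's pad id (0) and EOS id (1). The helper name `pad_to_longest` and the other token ids are hypothetical, made up for illustration rather than produced by the real SentencePiece tokenizer.

```python
from typing import List, Tuple

PAD_ID = 0  # T5's pad token id
EOS_ID = 1  # T5's end-of-sequence token id


def pad_to_longest(
    batch: List[List[int]], pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad each sequence to the longest in the batch and build a 0/1 attention mask.

    Hypothetical helper: a toy version of what `padding=True` does, not the tokenizer's code.
    """
    longest = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = longest - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)           # real tokens, then padding
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = attend, 0 = ignore
    return input_ids, attention_mask


# Made-up token ids; each sequence ends with the EOS id.
ids, mask = pad_to_longest([[101, 102, 103, EOS_ID], [104, EOS_ID]])
print(ids)   # [[101, 102, 103, 1], [104, 1, 0, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

The mask is what lets the encoder, and the decoder's cross-attention, skip the padded positions, so it should travel with the encoder outputs whenever a batch is padded.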
In [6]:
body = "Sea turtles (superfamily Chelonioidea), sometimes called marine turtles,[3] are reptiles of the order Testudines and of the suborder Cryptodira. The seven existing species of sea turtles are the flatback, green, hawksbill, leatherback, loggerhead, Kemp's ridley, and olive ridley sea turtles.[4] All of the seven species listed above, except for the flatback, are present in US waters, and are listed as endangered and/or threatened under the Endangered Species Act.[5] The flatback itself exists in the waters of Australia, Papua New Guinea and Indonesia.[5] Sea turtles can be categorized as hard-shelled (cheloniid) or leathery-shelled (dermochelyid).[6] The only dermochelyid species of sea turtle is the leatherback.[6]"
inputs = [f"summarize: {body}"]

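# Tokenize the prefixed input.
encoding = tokenizer(inputs, return_tensors="pt", padding=True)

# Inference only: disable gradient tracking for the encoder pass and the in-place edit.
with torch.no_grad():
    # Run T5's encoder to get the hidden states ("embeddings") for each input token.
    embeddings = model.encoder(**encoding)

    # Perturb the encoder hidden states with small Gaussian noise (mean 0, std 1e-3).
    embeddings.last_hidden_state += torch.normal(mean=0.0, std=1e-3, size=embeddings.last_hidden_state.shape)

# Decode the perturbed hidden states back to text with T5's decoder.
# Passing the attention mask keeps padded positions masked when the batch holds several inputs.
# No length is set, so generate() falls back to its default max_length of 20 tokens,
# which is likely why the summary below is cut off mid-list.
tokens = model.generate(encoder_outputs=embeddings, attention_mask=encoding["attention_mask"])
tokenizer.batch_decode(tokens, skip_special_tokens=True)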
Out[6]:
['the flatback, green, hawksbill, leatherback, loggerhead,']
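The cell above adds noise with `std=1e-3` to every element of the encoder output. For a rough sense of scale (a back-of-the-envelope sketch, not a measurement from this notebook): a d-dimensional vector of i.i.d. N(0, σ²) noise has E[‖n‖²] = dσ², so its norm is about σ√d. The sketch assumes t5-base's hidden size of 768, which is not printed anywhere above; the function names are hypothetical and only the standard library is used.

```python
import math
import random


def expected_noise_norm(std: float, dim: int) -> float:
    """Approximate L2 norm of a dim-dimensional vector of i.i.d. N(0, std^2) noise.

    Since E[||n||^2] = dim * std^2, the norm concentrates near std * sqrt(dim) for large dim.
    """
    return std * math.sqrt(dim)


def sampled_noise_norm(std: float, dim: int, seed: int = 0) -> float:
    """Draw one noise vector with a seeded RNG and return its actual L2 norm."""
    rng = random.Random(seed)
    return math.sqrt(sum(rng.gauss(0.0, std) ** 2 for _ in range(dim)))


D_MODEL = 768  # assumed hidden size of t5-base (one noise vector per token)
STD = 1e-3     # noise std used in the notebook cell above

print(f"expected per-token noise norm ~ {expected_noise_norm(STD, D_MODEL):.4f}")
print(f"one sampled per-token noise norm = {sampled_noise_norm(STD, D_MODEL):.4f}")
```

Both values land near 0.028. To judge whether that is "a little bit", one could compare it with `embeddings.last_hidden_state.norm(dim=-1)` computed before the noise is added; no such numbers are reported in this notebook.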