Namesformer
Before we get into the lecture, you can play with the trained model here: Namesformer Streamlit app.
This notebook is inspired by Andrej Karpathy's makemore lecture, which covers English name generation.
The code was written entirely by ChatGPT, with minimal corrections. My first query was:
I am preparing a lecture for my students on AI basics. They already know how to use attention in PyTorch to create self-attention layers. What I want to explain them is how to make a simplest possible transformer architecture (with minimal amount of code).
As a dataset I will use a csv with names:
john
peter
mike
...
And the goal will be to generate more names that sound name-like.
Give me an implementation with PyTorch trying to keep it as minimal as possible.
After that I had to ask for a couple of corrections, such as avoiding the Transformer layer, adding comments and fixing a bug in token indexing. All of them were relatively easy to spot. In less than an hour, this notebook was generating plausible-sounding names.
I decided to replace the original dataset with a list of Lithuanian names, which is easy to extract from vardai.vlkk.lt using the following code snippet:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import requests
from bs4 import BeautifulSoup

names = []
# One request per alphabetical page of the name list
for key in ['a', 'b', 'c', 'c-2', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
            'm', 'n', 'o', 'p', 'r', 's', 's-2', 't', 'u', 'v', 'z', 'z-2']:
    url = f'https://vardai.vlkk.lt/sarasas/{key}/'
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    # Keep only the links marked as male names
    links = soup.find_all('a', class_='names_list__links names_list__links--man')
    names += [name.text for name in links]

# Save one name per line under a 'name' header so pandas can read it back
np.savetxt('vardai.txt', names, fmt='%s', header='name', comments='', newline='\n')
```
If you want to experiment with English names, download them from here and use names.txt instead of vardai.txt.
Let’s add a space at the end of each name to mark where it ends. We also need a dictionary that encodes characters to integers and back, so let’s wrap that logic in a class.
```python
class NameDataset(Dataset):
    def __init__(self, csv_file):
        self.names = pd.read_csv(csv_file)['name'].values
        # Vocabulary: every character in the data plus ' ' (end-of-name and padding token)
        self.chars = sorted(list(set(''.join(self.names) + ' ')))
        self.char_to_int = {c: i for i, c in enumerate(self.chars)}
        self.int_to_char = {i: c for c, i in self.char_to_int.items()}
        self.vocab_size = len(self.chars)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx] + ' '  # Append the end-of-name character
        encoded_name = [self.char_to_int[char] for char in name]
        return torch.tensor(encoded_name)
```
```python
dataset = NameDataset('vardai.txt')
len(dataset)
```
3850
```python
dataset[0]
```
tensor([ 1, 82, 24, 23, 40, 0])
```python
[dataset.int_to_char[int(char)] for char in dataset[0]]
```
['A', '̃', 'b', 'a', 's', ' ']
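The same encode/decode logic can be tried on a toy vocabulary. The two names below are made up for illustration. The example also shows why ' ' ends up with index 0: it sorts before every letter, which is why each encoded name above ends in 0.

```python
# Toy vocabulary built the same way as in NameDataset (names are illustrative only)
toy_names = ['Rokas', 'Abas']
chars = sorted(set(''.join(toy_names) + ' '))
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for c, i in char_to_int.items()}

encoded = [char_to_int[c] for c in 'Abas ']       # encode, including the end marker
decoded = ''.join(int_to_char[i] for i in encoded)  # and decode back

print(chars)          # [' ', 'A', 'R', 'a', 'b', 'k', 'o', 's']
print(encoded)        # [1, 4, 3, 7, 0]
print(repr(decoded))  # 'Abas '
```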
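Note that this dataset is not simple: it contains capital letters and stress marks. The marks are stored as separate combining characters, so each becomes its own token. For example, the 'Ã' in the first name is encoded as 'A' followed by '̃'. Let’s intentionally keep it like this and see whether the model can figure it out. When you do this yourself, feel free to remove the accents and use only lowercase letters.
We need a way to construct padded batches. Because ' ' sorts before every letter, it gets index 0, so padding with 0 simply repeats the end-of-name token.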
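If you want to remove the accents, here is a minimal sketch. It assumes that the stress marks in this data are the combining grave (U+0300), acute (U+0301) and tilde (U+0303), which matches the decomposed 'A' + '̃' seen above. `strip_stress` and `STRESS_MARKS` are hypothetical names. Only these three marks are removed, so Lithuanian letters such as š, ą, ė or ū survive. A blanket "drop every combining mark" approach would turn š into s.

```python
import unicodedata

# Assumed set of stress accents: combining grave, acute and tilde
STRESS_MARKS = {'\u0300', '\u0301', '\u0303'}


def strip_stress(name: str, lowercase: bool = True) -> str:
    """Remove stress accents (and optionally capitals), keeping letters like š, ą, ū."""
    # NFD splits precomposed letters, e.g. 'ã' -> 'a' + U+0303
    decomposed = unicodedata.normalize('NFD', name)
    kept = ''.join(c for c in decomposed if c not in STRESS_MARKS)
    # NFC recombines what is left, e.g. 's' + caron -> 'š'
    result = unicodedata.normalize('NFC', kept)
    return result.lower() if lowercase else result


print(strip_stress('Rãšmaus'))  # rašmaus
print(strip_stress('Rõmicis'))  # romicis
```

Apply it to the names before the vocabulary is built, for example on `pd.read_csv(...)['name']` inside `NameDataset.__init__`, so the removed marks never become tokens.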
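```python
# Custom collate function for padding
def pad_collate(batch):
    # Pad every name to the longest one in the batch with 0 (the index of ' ')
    padded_seqs = pad_sequence(batch, batch_first=True, padding_value=0)
    input_seq = padded_seqs[:, :-1]   # every character except the last
    target_seq = padded_seqs[:, 1:]   # the same sequence shifted left by one
    return input_seq, target_seq

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=pad_collate)
```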
Make sure you understand what this generates and why.
```python
next(iter(dataloader))
```
(tensor([[ 1, 75, 26, 31, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[16, 23, 81, 42, 34, 31, 40, 0, 0, 0, 0, 0, 0],
[17, 23, 82, 40, 35, 23, 36, 41, 23, 40, 0, 0, 0],
[ 7, 31, 80, 36, 41, 23, 42, 41, 23, 40, 0, 0, 0],
[21, 31, 80, 33, 31, 40, 0, 0, 0, 0, 0, 0, 0],
[ 8, 23, 39, 37, 34, 26, 0, 0, 0, 0, 0, 0, 0],
[ 5, 80, 34, 31, 36, 29, 23, 40, 0, 0, 0, 0, 0],
[ 7, 37, 81, 41, 23, 39, 23, 40, 0, 0, 0, 0, 0],
[13, 37, 26, 27, 80, 40, 41, 23, 40, 0, 0, 0, 0],
[14, 37, 23, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 1, 34, 27, 33, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[ 5, 80, 26, 31, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[ 4, 37, 80, 39, 23, 40, 0, 0, 0, 0, 0, 0, 0],
[16, 23, 82, 41, 39, 31, 40, 0, 0, 0, 0, 0, 0],
[ 5, 28, 29, 27, 36, 31, 32, 42, 40, 0, 0, 0, 0],
[ 2, 23, 34, 41, 39, 23, 35, 31, 27, 82, 32, 42, 40],
[21, 31, 25, 41, 37, 39, 23, 40, 0, 0, 0, 0, 0],
[ 7, 58, 39, 79, 23, 26, 23, 40, 0, 0, 0, 0, 0],
[18, 33, 23, 31, 82, 26, 39, 31, 42, 40, 0, 0, 0],
[ 5, 31, 82, 43, 23, 39, 26, 23, 40, 0, 0, 0, 0],
[22, 27, 28, 31, 39, 31, 80, 36, 23, 40, 0, 0, 0],
[ 5, 36, 26, 39, 31, 27, 32, 42, 40, 0, 0, 0, 0],
[18, 23, 43, 70, 34, 31, 32, 42, 40, 0, 0, 0, 0],
[16, 23, 39, 28, 31, 39, 31, 32, 42, 40, 0, 0, 0],
[17, 70, 82, 26, 23, 40, 0, 0, 0, 0, 0, 0, 0],
[17, 23, 35, 42, 36, 23, 40, 0, 0, 0, 0, 0, 0],
[14, 23, 39, 46, 35, 23, 36, 41, 23, 40, 0, 0, 0],
[ 7, 27, 82, 26, 35, 23, 36, 41, 23, 40, 0, 0, 0],
[ 1, 31, 82, 29, 31, 36, 41, 23, 40, 0, 0, 0, 0],
[16, 34, 23, 41, 37, 36, 0, 0, 0, 0, 0, 0, 0],
[ 7, 43, 31, 80, 26, 35, 23, 36, 41, 23, 40, 0, 0],
[18, 23, 81, 36, 29, 27, 26, 23, 40, 0, 0, 0, 0]]),
tensor([[75, 26, 31, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[23, 81, 42, 34, 31, 40, 0, 0, 0, 0, 0, 0, 0],
[23, 82, 40, 35, 23, 36, 41, 23, 40, 0, 0, 0, 0],
[31, 80, 36, 41, 23, 42, 41, 23, 40, 0, 0, 0, 0],
[31, 80, 33, 31, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[23, 39, 37, 34, 26, 0, 0, 0, 0, 0, 0, 0, 0],
[80, 34, 31, 36, 29, 23, 40, 0, 0, 0, 0, 0, 0],
[37, 81, 41, 23, 39, 23, 40, 0, 0, 0, 0, 0, 0],
[37, 26, 27, 80, 40, 41, 23, 40, 0, 0, 0, 0, 0],
[37, 23, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[34, 27, 33, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[80, 26, 31, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[37, 80, 39, 23, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[23, 82, 41, 39, 31, 40, 0, 0, 0, 0, 0, 0, 0],
[28, 29, 27, 36, 31, 32, 42, 40, 0, 0, 0, 0, 0],
[23, 34, 41, 39, 23, 35, 31, 27, 82, 32, 42, 40, 0],
[31, 25, 41, 37, 39, 23, 40, 0, 0, 0, 0, 0, 0],
[58, 39, 79, 23, 26, 23, 40, 0, 0, 0, 0, 0, 0],
[33, 23, 31, 82, 26, 39, 31, 42, 40, 0, 0, 0, 0],
[31, 82, 43, 23, 39, 26, 23, 40, 0, 0, 0, 0, 0],
[27, 28, 31, 39, 31, 80, 36, 23, 40, 0, 0, 0, 0],
[36, 26, 39, 31, 27, 32, 42, 40, 0, 0, 0, 0, 0],
[23, 43, 70, 34, 31, 32, 42, 40, 0, 0, 0, 0, 0],
[23, 39, 28, 31, 39, 31, 32, 42, 40, 0, 0, 0, 0],
[70, 82, 26, 23, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[23, 35, 42, 36, 23, 40, 0, 0, 0, 0, 0, 0, 0],
[23, 39, 46, 35, 23, 36, 41, 23, 40, 0, 0, 0, 0],
[27, 82, 26, 35, 23, 36, 41, 23, 40, 0, 0, 0, 0],
[31, 82, 29, 31, 36, 41, 23, 40, 0, 0, 0, 0, 0],
[34, 23, 41, 37, 36, 0, 0, 0, 0, 0, 0, 0, 0],
[43, 31, 80, 26, 35, 23, 36, 41, 23, 40, 0, 0, 0],
[23, 81, 36, 29, 27, 26, 23, 40, 0, 0, 0, 0, 0]]))
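The batch above is hard to read, so here is the same input/target shift on two tiny, made-up encoded names, with 0 playing the role of ' '. At every position the target is simply the next character of the input. The first character of a name never appears as a target, so the model is never trained to choose it. That is why sampling later needs a `start_str`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "names" already encoded as integers; 0 plays the role of ' '
batch = [torch.tensor([5, 6, 7, 0]), torch.tensor([8, 9, 0])]
padded = pad_sequence(batch, batch_first=True, padding_value=0)  # [[5, 6, 7, 0], [8, 9, 0, 0]]

inputs, targets = padded[:, :-1], padded[:, 1:]
print(inputs.tolist())   # [[5, 6, 7], [8, 9, 0]]
print(targets.tolist())  # [[6, 7, 0], [9, 0, 0]]
```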
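Our transformer is built around self-attention. It has three parts: an embedding layer with learned positional encodings, a single PyTorch encoder layer, and a linear layer that maps back to the vocabulary.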
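```python
class MinimalTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, forward_expansion):
        super(MinimalTransformer, self).__init__()
        # Character embeddings plus learned positional encodings (up to 100 positions)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, 100, embed_size))
        # One encoder layer: multi-head self-attention followed by a feed-forward block.
        # Note: forward_expansion is not passed on, so dim_feedforward keeps its default (2048),
        # and batch_first defaults to False, so the layer reads inputs as (seq, batch, embed).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        # Project every position back to vocabulary logits
        self.output_layer = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        # x: (batch, seq) character indices
        x = self.embed(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.transformer_encoder(x)
        x = self.output_layer(x)  # (batch, seq, vocab_size) logits
        return x
```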
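One caveat about the model above. `nn.TransformerEncoderLayer` uses `batch_first=False` by default, so it treats the first dimension of the (batch, seq, embed) input as the sequence. As a result, self-attention mixes characters at the same position across different names in the batch, rather than earlier characters of the same name. During sampling (a batch of one), each character effectively attends only to itself.

Setting `batch_first=True` alone would let every position attend to the next input character, which is exactly its training target. A causal mask is therefore needed as well. Below is a sketch of such a variant; `CausalNameTransformer` and `max_len` are hypothetical names. I have not trained it, so no results are claimed. The loss logs and names in this notebook come from `MinimalTransformer`. The variant takes (batch, seq) indices and returns (batch, seq, vocab_size) logits, so it should plug into `train_model` and `sample` unchanged, e.g. `model = CausalNameTransformer(vocab_size=dataset.vocab_size)`.

```python
import torch
import torch.nn as nn


class CausalNameTransformer(nn.Module):
    """Hypothetical causal variant of MinimalTransformer (a sketch, not the notebook's model)."""

    def __init__(self, vocab_size, embed_size=128, num_heads=8, forward_expansion=4, max_len=100):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Learned positional encodings, as in the original model
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, embed_size))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=forward_expansion * embed_size,  # forward_expansion is actually used
            batch_first=True,  # inputs are (batch, seq, embed)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.output_layer = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        h = self.embed(x) + self.positional_encoding[:, :seq_len, :]
        # Causal mask: -inf above the diagonal, so position t attends only to positions <= t
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1
        )
        h = self.transformer_encoder(h, mask=causal_mask)
        return self.output_layer(h)  # (batch, seq, vocab_size) logits
```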
```python
# Training loop
def train_model(model, dataloader, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(epochs):
        model.train()  # Ensure the model is in training mode
        total_loss = 0.0
        batch_count = 0

        for batch_idx, (input_seq, target_seq) in enumerate(dataloader):
            optimizer.zero_grad()
            output = model(input_seq)
            # CrossEntropyLoss expects (batch, vocab, seq) logits, hence the transpose
            loss = criterion(output.transpose(1, 2), target_seq)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        average_loss = total_loss / batch_count
        print(f'Epoch {epoch+1}, Average Loss: {average_loss}')

model = MinimalTransformer(vocab_size=dataset.vocab_size, embed_size=128, num_heads=8, forward_expansion=4)
train_model(model, dataloader)
```
Epoch 1, Average Loss: 1.560861560923994
Epoch 2, Average Loss: 1.3083962241480174
Epoch 3, Average Loss: 1.275466939634528
Epoch 4, Average Loss: 1.2612562283011508
Epoch 5, Average Loss: 1.2426953862521275
Epoch 6, Average Loss: 1.2460837374048785
Epoch 7, Average Loss: 1.2364239564611892
Epoch 8, Average Loss: 1.2283017211709142
Epoch 9, Average Loss: 1.226946584941927
Epoch 10, Average Loss: 1.2349231972182093
Now we can generate a name by predicting it one letter at a time. The model returns logits, softmax turns them into probabilities, and we sample the next character from that probability distribution.
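Before looking at the full loop, here is the single "logits → probabilities → sample" step in isolation. It uses a made-up three-character vocabulary. The probabilities are non-negative and sum to 1. Over many draws, `torch.multinomial` picks each index roughly in proportion to its probability.

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.0])        # raw scores for a 3-character toy vocabulary
probabilities = torch.softmax(logits, dim=0)  # ≈ [0.665, 0.245, 0.090], sums to 1

# Draw many samples and compare empirical frequencies with the probabilities
draws = torch.multinomial(probabilities, num_samples=10_000, replacement=True)
frequencies = torch.bincount(draws, minlength=3).float() / draws.numel()

print(probabilities)
print(frequencies)  # close to the probabilities above
```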
```python
def sample(model, dataset, start_str='a', max_length=20):
    model.eval()  # Switch to evaluation mode
    with torch.no_grad():
        # Convert start string to tensor
        chars = [dataset.char_to_int[c] for c in start_str]
        input_seq = torch.tensor(chars).unsqueeze(0)  # Add batch dimension

        output_name = start_str
        for _ in range(max_length - len(start_str)):
            output = model(input_seq)

            # Logits for the next character sit at the last position
            probabilities = torch.softmax(output[0, -1], dim=0)

            # Sample a character from the probability distribution
            next_char_idx = torch.multinomial(probabilities, 1).item()
            next_char = dataset.int_to_char[next_char_idx]

            if next_char == ' ':  # ' ' is the end-of-name character
                break

            output_name += next_char

            # Update the input sequence for the next iteration
            input_seq = torch.cat([input_seq, torch.tensor([[next_char_idx]])], dim=1)

    return output_name

# After training your model, generate names starting with a specific letter
for _ in range(10):
    generated_name = sample(model, dataset, start_str='R')
    print(generated_name)
```
Rãvelijovas
Rãgmenas
Retuvytinton
Raĩnas
Rĩtris
Ranolìas
Roanis
Ranautontas
Rõmicis
Ralmas
Not bad! Note that this name is not in our names list.
```python
generated_name
```
'Ralmas'
```python
generated_name in names  # entries in names have no trailing ' ', so don't append one
```
False
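A single new name says little. A hypothetical helper, `novelty_rate`, can measure the share of generated names that do not appear in the training list. It strips whitespace so that names with or without the ' ' end marker compare equal. It assumes both lists use the same Unicode form, which should hold here because generated names are built from the dataset's own characters. The example names are illustrative.

```python
def novelty_rate(generated, training_names):
    """Share of generated names that are not already in the training list."""
    if not generated:
        return 0.0
    # Strip whitespace so 'Rokas' and 'Rokas ' count as the same name
    known = {str(name).strip() for name in training_names}
    novel = [name for name in generated if name.strip() not in known]
    return len(novel) / len(generated)


print(novelty_rate(['Rokas', 'Ralmas'], ['Rokas', 'Abas']))  # 0.5
```

In the notebook, this could be called as `novelty_rate([sample(model, dataset, start_str='R') for _ in range(100)], dataset.names)`.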
Let’s train for longer.
```python
train_model(model, dataloader, epochs=200)
```
Epoch 1, Average Loss: 1.2292306038958967
Epoch 2, Average Loss: 1.2296753338545807
Epoch 3, Average Loss: 1.2113789906186505
Epoch 4, Average Loss: 1.2123224966782185
Epoch 5, Average Loss: 1.2134187822499551
Epoch 6, Average Loss: 1.2100671367211775
Epoch 7, Average Loss: 1.2046438445729657
Epoch 8, Average Loss: 1.216528780696806
Epoch 9, Average Loss: 1.2090119542169178
Epoch 10, Average Loss: 1.204182753385591
Epoch 11, Average Loss: 1.19694076096716
Epoch 12, Average Loss: 1.2097524177929586
Epoch 13, Average Loss: 1.2037021379825497
Epoch 14, Average Loss: 1.2041516210422043
Epoch 15, Average Loss: 1.2032628212093321
Epoch 16, Average Loss: 1.1915799916283158
Epoch 17, Average Loss: 1.191510959105058
Epoch 18, Average Loss: 1.204029481273052
Epoch 19, Average Loss: 1.1954617096372873
Epoch 20, Average Loss: 1.1981768662279302
Epoch 21, Average Loss: 1.2015780131678937
Epoch 22, Average Loss: 1.2011014198468737
Epoch 23, Average Loss: 1.195754015248669
Epoch 24, Average Loss: 1.1914489850525027
Epoch 25, Average Loss: 1.1945332929122547
Epoch 26, Average Loss: 1.1859642592343418
Epoch 27, Average Loss: 1.1900481729468038
Epoch 28, Average Loss: 1.1984371597116643
Epoch 29, Average Loss: 1.1915138644620407
Epoch 30, Average Loss: 1.1930926890412639
Epoch 31, Average Loss: 1.1906916233133678
Epoch 32, Average Loss: 1.1895758549043955
Epoch 33, Average Loss: 1.1972257815116694
Epoch 34, Average Loss: 1.1903804575116181
Epoch 35, Average Loss: 1.1944811595372917
Epoch 36, Average Loss: 1.1970710424352284
Epoch 37, Average Loss: 1.1916139130749979
Epoch 38, Average Loss: 1.1847474373076572
Epoch 39, Average Loss: 1.1924392846990224
Epoch 40, Average Loss: 1.197125053602802
Epoch 41, Average Loss: 1.1875630809255868
Epoch 42, Average Loss: 1.1958131982275277
Epoch 43, Average Loss: 1.1888317726860362
Epoch 44, Average Loss: 1.1850869113748723
Epoch 45, Average Loss: 1.1943016303472282
Epoch 46, Average Loss: 1.1864741745073932
Epoch 47, Average Loss: 1.1840156661577461
Epoch 48, Average Loss: 1.1876781114861985
Epoch 49, Average Loss: 1.1812076967609815
Epoch 50, Average Loss: 1.1807344807080986
Epoch 51, Average Loss: 1.1867172599824007
Epoch 52, Average Loss: 1.184158510905652
Epoch 53, Average Loss: 1.1861653234347824
Epoch 54, Average Loss: 1.18255559324233
Epoch 55, Average Loss: 1.187564582864115
Epoch 56, Average Loss: 1.1790710707341343
Epoch 57, Average Loss: 1.1837314874672693
Epoch 58, Average Loss: 1.1867934169848102
Epoch 59, Average Loss: 1.186100725300056
Epoch 60, Average Loss: 1.1863626276165986
Epoch 61, Average Loss: 1.1842943878213237
Epoch 62, Average Loss: 1.1788255060014645
Epoch 63, Average Loss: 1.189123210335566
Epoch 64, Average Loss: 1.1844587252159748
Epoch 65, Average Loss: 1.183324110409445
Epoch 66, Average Loss: 1.1821246408233959
Epoch 67, Average Loss: 1.1777123020700186
Epoch 68, Average Loss: 1.1788903292545603
Epoch 69, Average Loss: 1.186970629967934
Epoch 70, Average Loss: 1.1829478819508197
Epoch 71, Average Loss: 1.1873527243117656
Epoch 72, Average Loss: 1.1788276197496526
Epoch 73, Average Loss: 1.180316432448458
Epoch 74, Average Loss: 1.1803383127716947
Epoch 75, Average Loss: 1.1891687940960087
Epoch 76, Average Loss: 1.1843605652328366
Epoch 77, Average Loss: 1.17897569837649
Epoch 78, Average Loss: 1.1766234498378658
Epoch 79, Average Loss: 1.175462582387215
Epoch 80, Average Loss: 1.1749394437498297
Epoch 81, Average Loss: 1.1846170775161302
Epoch 82, Average Loss: 1.1909018454472882
Epoch 83, Average Loss: 1.1792369213971226
Epoch 84, Average Loss: 1.1763497935838936
Epoch 85, Average Loss: 1.1755769646857395
Epoch 86, Average Loss: 1.1799335519144358
Epoch 87, Average Loss: 1.181434384554871
Epoch 88, Average Loss: 1.1844923865696615
Epoch 89, Average Loss: 1.1836248995843999
Epoch 90, Average Loss: 1.1798809003238835
Epoch 91, Average Loss: 1.178152390748016
Epoch 92, Average Loss: 1.178901397984875
Epoch 93, Average Loss: 1.1763924128753094
Epoch 94, Average Loss: 1.1776803554582203
Epoch 95, Average Loss: 1.1745359493681222
Epoch 96, Average Loss: 1.190951465575163
Epoch 97, Average Loss: 1.168050891111705
Epoch 98, Average Loss: 1.180815726272331
Epoch 99, Average Loss: 1.179542955288217
Epoch 100, Average Loss: 1.1771875082953902
Epoch 101, Average Loss: 1.169836976804024
Epoch 102, Average Loss: 1.1732540869515788
Epoch 103, Average Loss: 1.1739061696470277
Epoch 104, Average Loss: 1.1792582514857457
Epoch 105, Average Loss: 1.1752005519945758
Epoch 106, Average Loss: 1.1760616937944712
Epoch 107, Average Loss: 1.170486671865479
Epoch 108, Average Loss: 1.1782956773584539
Epoch 109, Average Loss: 1.1726476963886545
Epoch 110, Average Loss: 1.1795510324564846
Epoch 111, Average Loss: 1.1746162230318242
Epoch 112, Average Loss: 1.1758471683037182
Epoch 113, Average Loss: 1.1813066552493199
Epoch 114, Average Loss: 1.172285570093423
Epoch 115, Average Loss: 1.176815803385963
Epoch 116, Average Loss: 1.1789336785797244
Epoch 117, Average Loss: 1.1769179090980655
Epoch 118, Average Loss: 1.1783951860814055
Epoch 119, Average Loss: 1.1832648427033228
Epoch 120, Average Loss: 1.1782682405030431
Epoch 121, Average Loss: 1.17259880925013
Epoch 122, Average Loss: 1.1715694592018757
Epoch 123, Average Loss: 1.1720148526932583
Epoch 124, Average Loss: 1.1784934568996273
Epoch 125, Average Loss: 1.1775894431043263
Epoch 126, Average Loss: 1.1746587122767425
Epoch 127, Average Loss: 1.1776675808528239
Epoch 128, Average Loss: 1.1725553574640888
Epoch 129, Average Loss: 1.177438148782273
Epoch 130, Average Loss: 1.176710447496619
Epoch 131, Average Loss: 1.1731141327826444
Epoch 132, Average Loss: 1.1767307495282702
Epoch 133, Average Loss: 1.1693934473124417
Epoch 134, Average Loss: 1.1733752405347904
Epoch 135, Average Loss: 1.1659849537305595
Epoch 136, Average Loss: 1.179626563363824
Epoch 137, Average Loss: 1.1743027917609727
Epoch 138, Average Loss: 1.1785855022343723
Epoch 139, Average Loss: 1.1783450354229321
Epoch 140, Average Loss: 1.1774260229315638
Epoch 141, Average Loss: 1.1715978881544318
Epoch 142, Average Loss: 1.172811331335178
Epoch 143, Average Loss: 1.1768728031599818
Epoch 144, Average Loss: 1.1773389183785306
Epoch 145, Average Loss: 1.1743187510277615
Epoch 146, Average Loss: 1.1694534598303234
Epoch 147, Average Loss: 1.1702861293288302
Epoch 148, Average Loss: 1.1727704543712711
Epoch 149, Average Loss: 1.176124067838527
Epoch 150, Average Loss: 1.1815643192322787
Epoch 151, Average Loss: 1.1782729832594059
Epoch 152, Average Loss: 1.1701082518278074
Epoch 153, Average Loss: 1.1694987559121501
Epoch 154, Average Loss: 1.1757246812513051
Epoch 155, Average Loss: 1.1738169045487712
Epoch 156, Average Loss: 1.1750259355080028
Epoch 157, Average Loss: 1.1777134389916728
Epoch 158, Average Loss: 1.1760852149695404
Epoch 159, Average Loss: 1.1643548263005974
Epoch 160, Average Loss: 1.1709187139164319
Epoch 161, Average Loss: 1.1754598174213378
Epoch 162, Average Loss: 1.1712363441128375
Epoch 163, Average Loss: 1.1688077036014273
Epoch 164, Average Loss: 1.171680351919379
Epoch 165, Average Loss: 1.167231542512405
Epoch 166, Average Loss: 1.1670489754558595
Epoch 167, Average Loss: 1.170683554873979
Epoch 168, Average Loss: 1.1787078419992747
Epoch 169, Average Loss: 1.171857633373954
Epoch 170, Average Loss: 1.1711103881686187
Epoch 171, Average Loss: 1.1698308936820543
Epoch 172, Average Loss: 1.1775098812481588
Epoch 173, Average Loss: 1.1744549757192944
Epoch 174, Average Loss: 1.17041007644874
Epoch 175, Average Loss: 1.17597955760877
Epoch 176, Average Loss: 1.18260009525236
Epoch 177, Average Loss: 1.1756152614089084
Epoch 178, Average Loss: 1.1773499137113903
Epoch 179, Average Loss: 1.17570225709726
Epoch 180, Average Loss: 1.1799813922771738
Epoch 181, Average Loss: 1.1755351079396965
Epoch 182, Average Loss: 1.1673592575325453
Epoch 183, Average Loss: 1.1690961984563466
Epoch 184, Average Loss: 1.1743551965587395
Epoch 185, Average Loss: 1.1734651322207175
Epoch 186, Average Loss: 1.1740981872416725
Epoch 187, Average Loss: 1.1676894603681958
Epoch 188, Average Loss: 1.1670344974383835
Epoch 189, Average Loss: 1.1741712733733753
Epoch 190, Average Loss: 1.1770059106763728
Epoch 191, Average Loss: 1.1694683334058966
Epoch 192, Average Loss: 1.167220690526253
Epoch 193, Average Loss: 1.1743362151886807
Epoch 194, Average Loss: 1.1695056434505242
Epoch 195, Average Loss: 1.1659075292673977
Epoch 196, Average Loss: 1.180610185319727
Epoch 197, Average Loss: 1.1654542231362712
Epoch 198, Average Loss: 1.1696561141447588
Epoch 199, Average Loss: 1.1685029637715048
Epoch 200, Average Loss: 1.1642742590470747
```python
for _ in range(10):
    generated_name = sample(model, dataset, start_str='R')
    print(generated_name)
```
Rãšmaus
Ralionijus
Raydòlonas
Rãvijus
Reonaldas
Rijuas
Ror
Raĩslãkas
Rìrmoldas
Rõviudas
If we want the model to be more creative, we can add a temperature (creativity) control.
Question: does increasing the temperature increase or decrease the model’s creativity? What are its minimum and maximum values?
```python
def sample(model, dataset, start_str='a', max_length=20, temperature=1.0):
    assert temperature > 0, "Temperature must be greater than 0"
    model.eval()  # Switch model to evaluation mode
    with torch.no_grad():
        # Convert start string to tensor
        chars = [dataset.char_to_int[c] for c in start_str]
        input_seq = torch.tensor(chars).unsqueeze(0)  # Add batch dimension

        output_name = start_str
        for _ in range(max_length - len(start_str)):
            output = model(input_seq)

            # Apply temperature scaling to the logits of the last position
            logits = output[0, -1] / temperature
            probabilities = torch.softmax(logits, dim=0)

            # Sample a character from the probability distribution
            next_char_idx = torch.multinomial(probabilities, 1).item()
            next_char = dataset.int_to_char[next_char_idx]

            if next_char == ' ':  # ' ' is the end-of-name character
                break

            output_name += next_char

            # Update the input sequence for the next iteration
            input_seq = torch.cat([input_seq, torch.tensor([[next_char_idx]])], dim=1)

    return output_name

# Example usage with different temperatures
print('More confident:')
for _ in range(10):
    print(' ', sample(model, dataset, start_str='R', temperature=0.5))  # More confident

print('\nMore diverse/creative:')
for _ in range(10):
    print(' ', sample(model, dataset, start_str='R', temperature=1.5))  # More diverse
```
More confident:
Reraras
Raugìlas
Raũtas
Ravìmas
Rìlijus
Rãtas
Rìlius
Reris
Rãgas
Rìlijus
More diverse/creative:
Rntemijus
Rimtžvaus
Romènis
Romūdiutas
Ruolienas
Rẽdỹjuis
Reapẽndas
Rū̃drans
Rivì
Rámìk
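To see what dividing by the temperature does, apply it to a fixed, made-up set of logits. As the temperature grows, the most likely character loses probability mass to the less likely ones. This matches the more varied names generated at 1.5.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # made-up logits for a 3-character vocabulary

# Softmax of temperature-scaled logits for a few temperatures
distributions = {t: torch.softmax(logits / t, dim=0) for t in (0.5, 1.0, 1.5)}

for t, probs in distributions.items():
    print(t, [round(p, 3) for p in probs.tolist()])
# 0.5 [0.867, 0.117, 0.016]
# 1.0 [0.665, 0.245, 0.09]
# 1.5 [0.563, 0.289, 0.148]
```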
Here we go, we have a Lithuanian name generator!
Next, we can save the model and, with some help from ChatGPT, build a simple Streamlit app (https://namesformer.streamlit.app/).
TASK: add female names to the dataset, retrain the model (or train a second one) and create your own Streamlit app. You do not need the names leaderboard, since that requires a database. Any improvements are welcome.
```python
torch.save(model, 'namesformer_model.pt')
```
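`torch.save(model, ...)` pickles the whole module, which has two consequences for loading it later (for example in the Streamlit app). First, the `MinimalTransformer` class must be importable. Second, since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses such pickles unless you pass `weights_only=False`.

A common alternative is to save only the `state_dict` together with the vocabulary. Below is a sketch with two hypothetical helpers, `save_checkpoint` and `load_checkpoint`; neither is part of the notebook or of PyTorch. The payload holds only tensors, a dict and a list of strings, so the default `weights_only=True` loading accepts it. You still need to construct the model with the same `vocab_size` and sizes before loading.

```python
import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, dataset, path: str) -> None:
    """Store the weights plus the vocabulary needed to encode/decode characters."""
    torch.save({'state_dict': model.state_dict(), 'chars': list(dataset.chars)}, path)


def load_checkpoint(model: nn.Module, path: str):
    """Load weights into an already constructed model; return it with the vocabulary."""
    checkpoint = torch.load(path)  # plain tensors/dicts/lists load with weights_only=True
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model, checkpoint['chars']
```

In the app, you would rebuild `char_to_int` and `int_to_char` from the returned `chars`, exactly as `NameDataset` does.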