| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
class YourTextGenerationModel(nn.Module):
    """Token embedding -> single-layer LSTM -> linear projection to the vocabulary.

    Args:
        vocab_size: Number of token ids; sizes both the embedding table and
            the output projection.
        embedding_dim: Width of each token embedding (the LSTM input size).
        hidden_dim: Width of the LSTM hidden state (the projection input size).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Registration order (embedding, lstm, linear) determines state_dict
        # key order, so it is kept as-is for checkpoint compatibility.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids to per-position vocabulary scores.

        Args:
            x: Integer token-id tensor. Because the LSTM is built with
                ``batch_first=True``, a 2-D input is read as (batch, seq).

        Returns:
            The linear layer's raw (un-softmaxed) output for every position,
            e.g. (batch, seq, vocab_size) for a (batch, seq) input. The LSTM's
            final (h, c) state is discarded.
        """
        hidden_seq, _final_state = self.lstm(self.embedding(x))
        return self.linear(hidden_seq)

    def generate_text(self, prompt):
        """Return a placeholder string built from *prompt*.

        NOTE(review): this does not run the network or sample any tokens; it
        only prefixes the prompt. Real decoding would need a tokenizer, which
        is not visible here.
        """
        prefix = "Generated text for: "
        return prefix + prompt
| |
|
if __name__ == "__main__":
    # Demo hyperparameters for a small model instance.
    vocab_size, embedding_dim, hidden_dim = 10000, 128, 256

    your_model = YourTextGenerationModel(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
    )

    prompt = "Once upon a time"
    # generate_text is currently a placeholder and does not run the network.
    generated_text = your_model.generate_text(prompt)

    print("Input Prompt:", prompt)
    print("Generated Text:", generated_text)