File size: 1,249 Bytes
db59d0b
8cbdb92
 
 
 
 
db59d0b
8cbdb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, LayerNormalization, MultiHeadAttention, Dense, Add, Dropout, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
import numpy as np

class VoidChatModel(tf.keras.Model):
    def __init__(self, vocab_size, seq_len, num_layers=6, num_heads=8, emb_dim=512, mlp_dim=2048, dropout_rate=0.1):
        """Record the hyperparameters and create the model's sub-layers.

        Args:
            vocab_size: Number of distinct token ids; used as the embedding
                table size and as the width of the output layer.
            seq_len: Sequence length. It is only stored on the instance
                here; this constructor builds no layer from it.
            num_layers: How many ``TransformerBlock`` instances to create.
            num_heads: Forwarded to every ``TransformerBlock``.
            emb_dim: Embedding output dimension; also forwarded to every
                ``TransformerBlock``.
            mlp_dim: Forwarded to every ``TransformerBlock``.
            dropout_rate: Forwarded to every ``TransformerBlock``.
        """
        super().__init__()

        # Keep every hyperparameter on the instance for later reference.
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.emb_dim = emb_dim
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

        # Sub-layers are created in the same order as before (embedding,
        # blocks, output) so automatic layer naming and weight ordering
        # are unaffected.
        self.embedding = Embedding(
            input_dim=vocab_size,
            output_dim=emb_dim,
        )

        # One independently constructed block per requested layer.
        self.transformer_blocks = [
            TransformerBlock(num_heads, emb_dim, mlp_dim, dropout_rate)
            for _ in range(num_layers)
        ]

        # Final projection to vocabulary size with a softmax activation.
        self.output_layer = Dense(units=vocab_size, activation='softmax')
    
    def call(self, input_ids, training=False):
        # Embedding layer
        x = self.embedding(input_ids)