-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
29 lines (22 loc) · 809 Bytes
/
model.py
File metadata and controls
29 lines (22 loc) · 809 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
# from transformers import T5Config, T5ForConditionalGeneration
from torch import nn
class AbstractModel(nn.Module):
    """Base ``nn.Module`` that defines the interface concrete models implement.

    The constructor only stores the objects it is given; it creates no layers.
    Subclasses are expected to override :meth:`calculate_loss` and
    :meth:`generate`, both of which raise ``NotImplementedError`` here.
    """

    def __init__(
        self,
        config,
        dataset,
        tokenizer,
    ):
        """Store the configuration, dataset and tokenizer on the instance.

        Args:
            config: Model/run configuration object; kept as ``self.config``.
            dataset: Dataset object; kept as ``self.dataset``.
            tokenizer: Tokenizer object; kept as ``self.tokenizer``.

        NOTE(review): the concrete types of these three arguments are not
        visible in this file -- confirm against the subclasses and callers.
        """
        super().__init__()
        self.config = config
        self.dataset = dataset
        self.tokenizer = tokenizer

    @property
    def n_parameters(self):
        """Return a human-readable summary of the trainable parameter count.

        Only parameters with ``requires_grad=True`` are counted.

        Returns:
            str: ``'Total number of trainable parameters: <count>'``. Note
            that this is a formatted message, not the integer itself.
        """
        total_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return f'Total number of trainable parameters: {total_params}'

    def calculate_loss(self, batch):
        """Compute the training loss for ``batch``; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('calculate_loss method must be implemented.')

    def generate(self, batch, n_return_sequences=1):
        """Produce model outputs for ``batch``; must be overridden.

        Args:
            batch: Input batch (structure is defined by the subclass).
            n_return_sequences: Defaults to 1; presumably the number of
                sequences to return per input -- confirm in subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # The message names this method (`generate`) so that a subclass author
        # hitting the error knows exactly which method to override.
        raise NotImplementedError('generate method must be implemented.')