Skip to content

Commit 90cc9e9

Browse files
committed
update Selene with HeartENN model arch
1 parent cfc9671 commit 90cc9e9

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

models/heartenn.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
HeartENN architecture (Richter et al., 2020).
3+
"""
4+
import numpy as np
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class HeartENN(nn.Module):
10+
def __init__(self, sequence_length, n_genomic_features):
11+
"""
12+
Parameters
13+
----------
14+
sequence_length : int
15+
Length of sequence context on which to train.
16+
n_genomic_features : int
17+
The number of chromatin features to predict.
18+
19+
Attributes
20+
----------
21+
conv_net : torch.nn.Sequential
22+
classifier : torch.nn.Sequential
23+
24+
"""
25+
super(HeartENN, self).__init__()
26+
conv_kernel_size = 8
27+
pool_kernel_size = 4
28+
29+
self.conv_net = nn.Sequential(
30+
nn.Conv1d(4, 60, kernel_size=conv_kernel_size),
31+
nn.ReLU(inplace=True),
32+
nn.Conv1d(60, 60, kernel_size=conv_kernel_size),
33+
nn.ReLU(inplace=True),
34+
nn.MaxPool1d(
35+
kernel_size=pool_kernel_size, stride=pool_kernel_size),
36+
nn.BatchNorm1d(60),
37+
38+
nn.Conv1d(60, 80, kernel_size=conv_kernel_size),
39+
nn.ReLU(inplace=True),
40+
nn.Conv1d(80, 80, kernel_size=conv_kernel_size),
41+
nn.ReLU(inplace=True),
42+
nn.MaxPool1d(
43+
kernel_size=pool_kernel_size, stride=pool_kernel_size),
44+
nn.BatchNorm1d(80),
45+
nn.Dropout(p=0.4),
46+
47+
nn.Conv1d(80, 240, kernel_size=conv_kernel_size),
48+
nn.ReLU(inplace=True),
49+
nn.Conv1d(240, 240, kernel_size=conv_kernel_size),
50+
nn.ReLU(inplace=True),
51+
nn.BatchNorm1d(240),
52+
nn.Dropout(p=0.6))
53+
54+
reduce_by = 2 * (conv_kernel_size - 1)
55+
pool_kernel_size = float(pool_kernel_size)
56+
self._n_channels = int(
57+
np.floor(
58+
(np.floor(
59+
(sequence_length - reduce_by) / pool_kernel_size)
60+
- reduce_by) / pool_kernel_size)
61+
- reduce_by)
62+
self.classifier = nn.Sequential(
63+
nn.Linear(240 * self._n_channels, n_genomic_features),
64+
nn.ReLU(inplace=True),
65+
nn.BatchNorm1d(n_genomic_features),
66+
nn.Linear(n_genomic_features, n_genomic_features),
67+
nn.Sigmoid())
68+
69+
def forward(self, x):
70+
"""Forward propagation of a batch.i
71+
72+
"""
73+
for layer in self.conv_net.children():
74+
if isinstance(layer, nn.Conv1d):
75+
layer.weight.data.renorm_(2, 0, 0.9)
76+
for layer in self.classifier.children():
77+
if isinstance(layer, nn.Linear):
78+
layer.weight.data.renorm_(2, 0, 0.9)
79+
out = self.conv_net(x)
80+
reshape_out = out.view(out.size(0), 240 * self._n_channels)
81+
predict = self.classifier(reshape_out)
82+
return predict
83+
84+
def criterion():
85+
return nn.BCELoss()
86+
87+
def get_optimizer(lr):
88+
return (torch.optim.SGD,
89+
{"lr": lr, "weight_decay": 1e-6, "momentum": 0.9})

0 commit comments

Comments
 (0)