alvawalt committed
Commit 2803b0b · verified · 1 Parent(s): cff0942

Upload 3 files

Files changed (3)
  1. config.json +12 -0
  2. model.py +230 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,12 @@
+{
+  "model_type": "bulkformer",
+  "num_genes": 19357,
+  "dim": 320,
+  "gb_repeat": 1,
+  "bins": 10,
+  "bin_head": 8,
+  "full_head": 4,
+  "p_repeat": 2,
+  "training_epoch": 4,
+  "final_loss": 0.2695700265943767
+}
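The fields in config.json map one-to-one onto the constructor arguments of BulkFormer in model.py (num_genes corresponds to gene_length). A minimal instantiation sketch, assuming the commit's own model.py is importable; the gene_emb matrix and the gene-gene edge_index below are placeholders, since neither the ESM2 gene embeddings nor the graph ships with this upload, and the embedding width of 1280 is an assumption.

import json
import torch
from model import BulkFormer   # model.py from this commit

with open("config.json") as f:
    cfg = json.load(f)

# Placeholders: the real ESM2 gene-identity embeddings and the
# gene-gene graph (COO edge_index) are not part of this upload.
gene_emb = torch.randn(cfg["num_genes"], 1280)                # assumed ESM2 width
edge_index = torch.randint(0, cfg["num_genes"], (2, 50_000))  # dummy edges

model = BulkFormer(
    dim=cfg["dim"],
    graph=edge_index,
    gene_emb=gene_emb,
    gene_length=cfg["num_genes"],
    bin_head=cfg["bin_head"],
    full_head=cfg["full_head"],
    bins=cfg["bins"],
    gb_repeat=cfg["gb_repeat"],
    p_repeat=cfg["p_repeat"],
)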
model.py ADDED
@@ -0,0 +1,230 @@
+# ------------------------------------------------------------
+# CancerTranscriptome-Mini-48M
+# Model: Lightweight adaptation of BulkFormer
+# Author: Walter Alvarado (NASA Ames Research Center)
+# License: MIT
+#
+# References:
+# (1) Boming Kang, Rui Fan, Meizheng Yi, Chunmei Cui, Qinghua Cui.
+#     “A large-scale foundation model for bulk transcriptomes.”
+#     bioRxiv (2025). doi:10.1101/2025.06.11.659222
+#
+# (2) Alvarado W. “CancerTranscriptome-Mini-48M: A compact cancer-
+#     focused BulkFormer derivative.” https://github.com/alwalt/BioFM
+#
+# Data Source:
+#     ARCHS4 Human RNA-seq v2.5 (Lachmann et al., Nat Commun 2018)
+# ------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from torch_geometric.nn.conv import GCNConv
+from performer_pytorch import Performer
+
+# Default model hyperparameters
+model_params = {
+    "dim": 320,
+    "bins": 10,
+    "gb_repeat": 1,
+    "p_repeat": 2,
+    "bin_head": 8,
+    "full_head": 4,
+    "gene_length": 19357
+}
+
+# ------------------------------------------------------------
+# Rotary Expression Embedding (REE)
+# ------------------------------------------------------------
+
+class PositionalExprEmbedding(nn.Module):
+    """
+    Rotary Expression Embedding (REE):
+    Converts continuous gene expression values into a sinusoidal
+    embedding usable by Performer/Transformer blocks. Deterministic,
+    not learned. Masked positions (-10) → zero vector.
+    """
+    def __init__(self, dim, mask_token=-10):
+        super().__init__()
+        self.mask_token = mask_token
+        self.inv_freq = nn.Parameter(
+            1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim)),
+            requires_grad=False
+        )
+
+    def forward(self, x):
+        mask = (x == self.mask_token).nonzero(as_tuple=False)
+        x = torch.einsum("bi,j->bij", x, self.inv_freq)
+        x = torch.cat([x.sin(), x.cos()], dim=-1)
+        x[mask[:, 0], mask[:, 1]] = 0
+        return x
+
+
+# ------------------------------------------------------------
+# GBFormer Block (Graph + Local Performer + Global Performer)
+# ------------------------------------------------------------
+
+class GBFormer(nn.Module):
+    """
+    A single GBFormer block:
+    - LayerNorm
+    - GCNConv (gene-gene propagation)
+    - Binning by learned importance score
+    - Local Performer per-bin
+    - Global Performer
+    """
+    def __init__(self, dim, gene_length, bin_head, full_head, bins, p_repeat):
+        super().__init__()
+
+        self.dim = dim
+        self.bins = bins
+        self.bin_head = bin_head
+        self.full_head = full_head
+        self.p_repeat = p_repeat
+
+        self.layernorm = nn.LayerNorm(dim)
+        self.gcn = GCNConv(dim, dim, cached=True, add_self_loops=False)
+
+        # Learned scoring → assign each gene to a bin
+        self.which_bin = nn.Linear(dim, 1)
+
+        # Local Performer per bin
+        self.bin_layers = nn.ModuleList([
+            Performer(
+                dim=dim,
+                heads=bin_head,
+                depth=1,
+                dim_head=dim // bin_head,
+                attn_dropout=0.2,
+                ff_dropout=0.2
+            )
+            for _ in range(bins)
+        ])
+
+        # Global Performer stack
+        self.global_layers = nn.Sequential(*[
+            Performer(
+                dim=dim,
+                heads=full_head,
+                depth=1,
+                dim_head=dim // full_head
+            )
+            for _ in range(p_repeat)
+        ])
+
+    def forward(self, x, graph):
+        B, G, D = x.shape
+
+        x = self.layernorm(x)
+        x = x + self.gcn(x, graph)  # residual GCN update
+
+        if self.bins > 0:
+            scores = self.which_bin(x).squeeze(-1)  # [B, G]
+            order = torch.argsort(scores, dim=1, descending=True)
+            order_full = order.unsqueeze(-1).expand(-1, -1, D)
+
+            x_sorted = x.gather(1, order_full)
+            bin_size = (G - 1) // self.bins + 1
+            chunks = torch.split(x_sorted, bin_size, dim=1)
+
+            processed = [
+                layer(chunk)
+                for chunk, layer in zip(chunks, self.bin_layers)
+            ]
+
+            x_cat = torch.cat(processed, dim=1)
+            x = torch.empty_like(x_cat).scatter_(1, order_full, x_cat)
+
+        x = self.global_layers(x)
+        return x
+
+
+# ------------------------------------------------------------
+# Full BulkFormer Model
+# ------------------------------------------------------------
+
+class BulkFormer(nn.Module):
+    """
+    CancerTranscriptome-Mini-48M:
+    A compact BulkFormer-style masked-expression model.
+    Combines:
+    - ESM2 gene identity embeddings
+    - Rotary Expression Embeddings (REE)
+    - Graph Convolution (GCNConv)
+    - Local/global Performer attention
+    - Optional intermediate repr_layers for feature extraction
+    """
+    def __init__(
+        self,
+        dim,
+        graph,
+        gene_emb,
+        gene_length,
+        bin_head=4,
+        full_head=4,
+        bins=10,
+        gb_repeat=1,
+        p_repeat=1
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.graph = graph
+        self.gene_length = gene_length
+
+        # Identity embeddings from ESM2 (trainable projection)
+        self.gene_emb = nn.Parameter(gene_emb)
+        self.gene_proj = nn.Sequential(
+            nn.Linear(gene_emb.shape[1], 4 * dim),
+            nn.ReLU(),
+            nn.Linear(4 * dim, dim)
+        )
+
+        # REE for expression
+        self.expr_emb = PositionalExprEmbedding(dim)
+
+        # Pre-attention mixing layer
+        self.mix = nn.Sequential(
+            nn.Linear(dim, 4 * dim),
+            nn.ReLU(),
+            nn.Linear(4 * dim, dim)
+        )
+
+        # Stacked GBFormer blocks
+        self.gb_blocks = nn.ModuleList([
+            GBFormer(dim, gene_length, bin_head, full_head, bins, p_repeat)
+            for _ in range(gb_repeat)
+        ])
+
+        self.final_norm = nn.LayerNorm(dim)
+
+        # Output head → scalar prediction per gene
+        self.head = nn.Sequential(
+            nn.Linear(dim, 4 * dim),
+            nn.ReLU(),
+            nn.Linear(4 * dim, 1),
+            nn.ReLU()
+        )
+
+    def forward(self, x, repr_layers=None):
+        B, G = x.shape
+        hidden = {}
+
+        x = (
+            self.expr_emb(x) +
+            self.gene_proj(self.gene_emb) +
+            torch.zeros(B, 1, self.dim, device=x.device)  # no AE latent in this version
+        )
+
+        x = self.mix(x)
+
+        for i, block in enumerate(self.gb_blocks):
+            x = block(x, self.graph)
+            if repr_layers and i in repr_layers:
+                hidden[i] = x
+
+        x = self.final_norm(x)
+        out = self.head(x).squeeze(-1)
+
+        if repr_layers:
+            return out, hidden
+        return out
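For reference, a forward-pass sketch against the model built from config.json above: the input is a [batch, num_genes] matrix of expression values, with the -10 sentinel marking masked entries (see PositionalExprEmbedding), and repr_layers optionally returns hidden states from the listed GBFormer blocks. The random input below is illustrative only.

# One row per sample, one column per gene; -10 marks masked positions.
expr = torch.randn(2, cfg["num_genes"])
expr[:, :100] = -10   # mask the first 100 genes, purely for illustration

model.eval()
with torch.no_grad():
    preds = model(expr)                           # [2, num_genes] per-gene predictions
    preds, hidden = model(expr, repr_layers=[0])  # also collect block-0 states
    features = hidden[0]                          # [2, num_genes, dim] gene-level features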
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06205e23331567f7c90e7338493fecef3e5d775349196966dbcda35175c5760b
+size 253500552
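model.safetensors holds the trained weights referenced by the Git LFS pointer above (~253 MB). A loading sketch using the safetensors library, under the assumption that the checkpoint's key names match the module tree defined in model.py; strict=False reports any mismatch instead of raising.

from safetensors.torch import load_file

state_dict = load_file("model.safetensors")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing, unexpected)   # both should be empty if the key names line up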