Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,41 @@ def __init__(self, config: Dict[str, Any], **kwargs: Any):
self.discriminator = ElectraForPreTraining(self.discriminator_config)
self.replace_p = 0.1

def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
    """
    Processes the batch data, cutting off x (and the mask, if one is present)
    to max_position_embeddings of the generator config.

    Args:
        batch (XYData): The input batch of data. Although annotated as a
            dict, it is accessed through its ``x`` and ``additional_fields``
            attributes.
        batch_idx (int): The index of the current batch (not used here).

    Returns:
        Dict[str, Any]: Processed batch data with the keys ``features``,
            ``labels``, ``model_kwargs``, ``loss_kwargs`` and ``idents``.

    Raises:
        Exception: Any error raised while slicing the mask is reported via
            ``print`` and then re-raised unchanged.
    """
    max_len = self.generator_config.max_position_embeddings

    # cut off to max length of max_position_embeddings
    x = batch.x[:, :max_len]

    # Shallow copy: replacing "mask" below must not alter the dict that is
    # still owned by the batch.
    model_kwargs = dict(batch.additional_fields["model_kwargs"])
    if "mask" in model_kwargs:
        mask = model_kwargs["mask"]
        try:
            model_kwargs["mask"] = mask[:, :max_len]
        except Exception as e:
            # getattr keeps this handler from raising its own AttributeError
            # (and thereby hiding the real error) when the mask has no shape.
            print(
                f"Failed to cut off mask {getattr(mask, 'shape', None)} to max_position_embeddings: {e}"
            )
            # Bare raise re-raises the original exception with its traceback.
            raise

    return dict(
        features=x,
        labels=self._process_labels_in_batch(batch),
        model_kwargs=model_kwargs,
        loss_kwargs=batch.additional_fields["loss_kwargs"],
        idents=batch.additional_fields["idents"],
    )

@property
def as_pretrained(self) -> ElectraForPreTraining:
"""
Expand Down Expand Up @@ -204,8 +239,13 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
* CLS_TOKEN
)
model_kwargs["output_attentions"] = True

x = torch.cat((cls_tokens, batch.x), dim=1)
# cut off to max length of max_position_embeddings
x = x[:, : self.config.max_position_embeddings]

return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
features=x,
labels=batch.y,
model_kwargs=model_kwargs,
loss_kwargs=loss_kwargs,
Expand Down
27 changes: 24 additions & 3 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
import tempfile
from datetime import datetime
from typing import Generator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

import pandas as pd
import requests
Expand Down Expand Up @@ -284,10 +284,10 @@ def __init__(self, train_batch_size=1_000_000, *args, **kwargs):
self.test_batch_size = 100_000

@property
def processed_file_names_dict(self) -> List[str]:
def processed_file_names_dict(self) -> Dict[str, str]:
"""
Returns:
List[str]: List of processed data file names.
Dict[str, str]: Dictionary of processed data file names.
"""
train_samples = (
self._n_samples
Expand Down Expand Up @@ -404,6 +404,27 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
**kwargs,
)

def load_processed_data(
    self, kind: Optional[str] = None, filename: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Loads processed data from a specified dataset type or file. The data is read
    directly from file instead of going through the dynamic_splits_df property,
    so that a new training batch is loaded for each epoch.

    Args:
        kind (Optional[str]): Key into ``processed_file_names_dict`` selecting
            the file to load. Only consulted when ``filename`` is None.
        filename (Optional[str]): File to load. Takes precedence over ``kind``.

    Returns:
        List[Dict[str, Any]]: Whatever ``load_processed_data_from_file`` returns
            for the selected file.

    Raises:
        ValueError: If both ``kind`` and ``filename`` are None.
    """
    # An explicit filename always wins, even if kind is given as well.
    if filename is not None:
        return self.load_processed_data_from_file(filename)

    if kind is None:
        raise ValueError(
            "Either kind or filename is required to load the correct dataset, both are None"
        )

    # Only kind was given: resolve it to its processed file name.
    resolved_name = self.processed_file_names_dict[kind]
    return self.load_processed_data_from_file(resolved_name)


class LabeledUnlabeledMixed(XYBaseDataModule):
"""
Expand Down
4 changes: 2 additions & 2 deletions configs/model/electra-for-pretraining.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ init_args:
lr: 1e-3
config:
generator:
vocab_size: 4400
vocab_size: 600
max_position_embeddings: 1800
num_attention_heads: 8
num_hidden_layers: 6
type_vocab_size: 1
discriminator:
vocab_size: 4400
vocab_size: 600
max_position_embeddings: 1800
num_attention_heads: 8
num_hidden_layers: 6
Expand Down
2 changes: 1 addition & 1 deletion configs/model/electra300.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ init_args:
optimizer_kwargs:
lr: 1e-3
config:
vocab_size: 4400
vocab_size: 600
max_position_embeddings: 301
num_attention_heads: 8
num_hidden_layers: 6
Expand Down
Loading