| from abc import abstractmethod | |
| from transformers import PretrainedConfig | |
| class MeasurementPredictorConfig(PretrainedConfig): | |
| def __init__( | |
| self, | |
| sensor_token=" omit", | |
| sensor_token_id=None, # 35991 | |
| n_sensors=3, | |
| use_aggregated=True, | |
| sensors_weight = 0.7, | |
| aggregate_weight=0.3, | |
| **kwargs | |
| ): | |
| self.sensor_token = sensor_token | |
| self.sensor_token_id = sensor_token_id | |
| self.n_sensors = n_sensors | |
| self.use_aggregated = use_aggregated | |
| self.sensors_weight = sensors_weight | |
| self.aggregate_weight = aggregate_weight | |
| super().__init__(**kwargs) | |
| self.emb_dim = self.get_emb_dim() | |
| def get_emb_dim(self): | |
| raise NotImplementedError |