Commit fcff61b
Parent(s): eec3f65

Add bandaid for empty strings

run_speech_recognition_seq2seq.py CHANGED
@@ -46,7 +46,6 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
-
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.17.0.dev0")
 
@@ -89,7 +88,7 @@ class ModelArguments:
         default=False,
         metadata={
             "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
-
+            "with private models)."
         },
     )
     freeze_feature_encoder: bool = field(
@@ -124,14 +123,14 @@ class DataTrainingArguments:
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-
+            "value if set."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-
+            "value if set."
         },
     )
     audio_column_name: str = field(
@@ -155,9 +154,9 @@ class DataTrainingArguments:
         default=False,
         metadata={
             "help": "Whether to only do data preprocessing and skip training. "
-
-
-
+            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
+            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
+            "so that the cached datasets can consequently be loaded in distributed training"
         },
     )
     train_split_name: str = field(
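The three hunks above restore the continuation lines of multi-line --help descriptions; they rely on Python's implicit concatenation of adjacent string literals inside the field(metadata={"help": ...}) pattern that HfArgumentParser reads. A minimal sketch of that pattern, where the ToyArguments class and its single field are illustrative and not part of the script:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            # Adjacent string literals are joined at compile time, so this help
            # text is one string even though it spans two source lines.
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )


parser = HfArgumentParser(ToyArguments)
(toy_args,) = parser.parse_args_into_dataclasses(args=["--max_train_samples", "8"])
print(toy_args.max_train_samples)  # 8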
@@ -283,12 +282,14 @@ def main():
 
     if training_args.do_train:
         raw_datasets["train"] = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name,
+            data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name,
+            cache_dir=data_args.data_cache_dir
         )
 
     if training_args.do_eval:
         raw_datasets["eval"] = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name,
+            data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name,
+            cache_dir=data_args.data_cache_dir
         )
 
     if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
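The hunk above threads a cache_dir through both load_dataset calls, so downloaded and processed splits land under a user-chosen directory instead of the default ~/.cache/huggingface/datasets; data_args.data_cache_dir is presumably defined elsewhere in DataTrainingArguments and is not shown in this diff. A rough equivalent with placeholder names:

from datasets import load_dataset

# Placeholder dataset name, config, split and cache path; in the script these
# all come from data_args.
raw_train = load_dataset(
    "username/my_asr_dataset",         # data_args.dataset_name
    "default",                         # data_args.dataset_config_name
    split="train",                     # data_args.train_split_name
    cache_dir="/scratch/hf_datasets",  # data_args.data_cache_dir
)
print(raw_train.column_names)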
@@ -378,6 +379,8 @@ def main():
         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
 
         input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, re.IGNORECASE)
+        if len(input_str) == 0:
+            input_str = "."  # bandaid
 
         batch["labels"] = tokenizer(input_str).input_ids
         return batch
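The final hunk is the "bandaid" named in the commit message: after annotation tags such as <ee>, <qq>, <mm> and <inaudible> are stripped, some transcripts presumably end up empty, and the fallback substitutes "." so the tokenizer never receives an empty target. Below is a self-contained sketch of that label-preparation step; the checkpoint name is a placeholder, and the sketch passes re.IGNORECASE through the flags= keyword, since re.sub's fourth positional parameter is count:

import re

from transformers import AutoTokenizer

# Placeholder checkpoint; the script builds its tokenizer from model_args.
tokenizer = AutoTokenizer.from_pretrained("facebook/s2t-small-librispeech-asr")


def prepare_labels(text, do_lower_case=True):
    input_str = text.lower() if do_lower_case else text
    # Strip annotation tags such as <ee>, <*qq> or <inaudible>.
    input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, flags=re.IGNORECASE)
    if len(input_str) == 0:
        input_str = "."  # bandaid: never hand the tokenizer an empty target
    return tokenizer(input_str).input_ids


print(prepare_labels("<inaudible>"))  # label ids for "." rather than for an empty transcript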