Spaces: Runtime error
fix: model config
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -82,7 +82,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
| 82 |
|
| 83 |
# Model hyperparameters, for convenience
|
| 84 |
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
|
| 85 |
-
OUTPUT_LENGTH = 256 +
|
| 86 |
BOS_TOKEN_ID = 16384
|
| 87 |
BASE_MODEL = 'facebook/bart-large-cnn'
|
| 88 |
|
|
@@ -425,6 +425,8 @@ def main():
|
|
| 425 |
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
| 426 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
| 427 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
|
|
|
|
|
|
| 428 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
| 429 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
| 430 |
|
|
|
|
| 82 |
|
| 83 |
# Model hyperparameters, for convenience
|
| 84 |
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
|
| 85 |
+
OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
|
| 86 |
BOS_TOKEN_ID = 16384
|
| 87 |
BASE_MODEL = 'facebook/bart-large-cnn'
|
| 88 |
|
|
|
|
| 425 |
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
| 426 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
| 427 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
| 428 |
+
config.forced_bos_token_id = None # we don't need this token
|
| 429 |
+
config.forced_eos_token_id = None # we don't need this token
|
| 430 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
| 431 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
| 432 |
|