Spaces:
Runtime error
Runtime error
fix: wandb logging with sync_tensorboard
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -755,7 +755,8 @@ def main():
|
|
| 755 |
|
| 756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
| 757 |
for k, v in unreplicate(train_metric).items():
|
| 758 |
-
wandb.log({
|
|
|
|
| 759 |
|
| 760 |
train_time += time.time() - train_start
|
| 761 |
|
|
|
|
| 755 |
|
| 756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
| 757 |
for k, v in unreplicate(train_metric).items():
|
| 758 |
+
wandb.log({"train/step": global_step})
|
| 759 |
+
wandb.log({f"train/{k}": jax.device_get(v)})
|
| 760 |
|
| 761 |
train_time += time.time() - train_start
|
| 762 |
|