Continual Learning

Dataset

The full dataset can be downloaded from here: http://clevrer.csail.mit.edu.

VAE

Instead of working in high-resolution pixel space, I trained a VAE to represent the frames in a lower-dimensional latent space, which makes training faster, more efficient, and cheaper. After training it for ~300 epochs, reconstructions of test-set latents are shown below: the frames on the left are ground-truth test frames, and the ones on the right are reconstructed by the decoder.

VAE Reconstruction
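The latent bottleneck can be sketched as follows. This is a minimal numpy illustration of the VAE mechanics (the stand-in linear encoder/decoder, the 64-dim latent size, and all weights here are assumptions, not the trained model): a 128x128x3 frame is mapped to a small latent via the reparameterization trick, and the temporal model would then operate on these latents instead of pixels.

```python
import numpy as np

rng = np.random.default_rng(0)

frame = rng.random((128, 128, 3)).astype(np.float32)  # one RGB frame
x = frame.reshape(-1)                                 # flatten: 49152 dims
latent_dim = 64                                       # assumed latent size

# Stand-in linear "encoder" producing mean and log-variance of q(z|x)
W_mu = rng.standard_normal((latent_dim, x.size)).astype(np.float32) * 0.001
W_logvar = rng.standard_normal((latent_dim, x.size)).astype(np.float32) * 0.001
mu = W_mu @ x
logvar = W_logvar @ x

# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
eps = rng.standard_normal(latent_dim).astype(np.float32)
z = mu + np.exp(0.5 * logvar) * eps

# Stand-in linear "decoder" mapping the latent back to pixel space
W_dec = rng.standard_normal((x.size, latent_dim)).astype(np.float32) * 0.001
x_hat = (W_dec @ z).reshape(128, 128, 3)

print(z.shape)      # (64,)  -- the temporal model works at this size
print(x_hat.shape)  # (128, 128, 3)
```

The point of the sketch is the compression ratio: the downstream model sees a 64-dim vector per frame rather than ~49k pixels.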

Optical Flow Model

To obtain ground-truth optical flows for supervising the flow head, we can use a pre-trained optical flow model such as RAFT [1], FlowFormer [2], SEA-RAFT [3], or WAFT [4] and precompute the flow fields between each pair of consecutive frames offline.

In this system, I am using RAFT.
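The offline precompute loop has a simple shape, sketched below. The function names and data layout are assumptions, not the project's actual API; in a real run `estimate_flow` would be a RAFT forward pass, and the outputs would be saved to disk alongside the frames.

```python
import numpy as np

def estimate_flow(frame_a, frame_b):
    # Placeholder: a pre-trained RAFT model would go here. We return a
    # zero flow field of shape (H, W, 2) just to show the data layout.
    h, w = frame_a.shape[:2]
    return np.zeros((h, w, 2), dtype=np.float32)

def precompute_flows(frames):
    # One flow field per consecutive frame pair: N frames -> N-1 flows.
    return [estimate_flow(a, b) for a, b in zip(frames[:-1], frames[1:])]

frames = [np.zeros((128, 128, 3), dtype=np.float32) for _ in range(20)]
flows = precompute_flows(frames)
print(len(flows))      # 19
print(flows[0].shape)  # (128, 128, 2)
```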

Test-Time Training (TTT)

Assuming we choose approach 3, we can utilize the flow predictor head during test-time training. After training the three heads jointly, we continue updating the flow predictor on each video during inference, and reset it to its pre-trained weights before each new video.
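The reset-per-video loop can be sketched as below. The linear "flow head", the shapes, the step count, and the learning rate are all stand-ins for the real network; the point is that adaptation always starts from a copy of the jointly trained weights, never from the previous video's adapted state.

```python
import numpy as np

rng = np.random.default_rng(0)

def ttt_adapt(W_pretrained, video_feats, video_flows, steps=10, lr=1e-2):
    # Copy so the pretrained weights are never mutated in place.
    W = W_pretrained.copy()
    for _ in range(steps):
        pred = video_feats @ W                              # predicted flow targets
        grad = video_feats.T @ (pred - video_flows) / len(video_feats)
        W -= lr * grad                                      # one TTT update step
    return W

W0 = rng.standard_normal((8, 2)) * 0.1  # jointly trained flow-head weights
for video in range(3):
    feats = rng.standard_normal((32, 8))
    flows = rng.standard_normal((32, 2))
    W_adapted = ttt_adapt(W0, feats, flows)
    # W0 is untouched, so each new video resets to the pretrained head.
    assert not np.allclose(W_adapted, W0)
```

In a torch implementation the same effect is usually achieved by reloading the head's saved state dict at the start of each video.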

Loss Functions

VAE
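Assuming the usual ELBO formulation (the actual weighting used in this project may differ), the VAE loss is a reconstruction term plus the KL divergence of q(z|x) = N(mu, sigma^2) from the standard normal prior; `beta` below is a hypothetical weighting knob.

```python
import numpy as np

def vae_loss(x, x_hat, mu, logvar, beta=1.0):
    # Reconstruction term: mean squared error in pixel (or latent) space.
    recon = np.mean((x - x_hat) ** 2)
    # KL(N(mu, sigma^2) || N(0, 1)), closed form, averaged over dims.
    kl = -0.5 * np.mean(1 + logvar - mu**2 - np.exp(logvar))
    return recon + beta * kl

# With a perfect reconstruction and q(z|x) equal to the prior, the loss is 0.
x = np.ones((4, 16)); mu = np.zeros((4, 8)); logvar = np.zeros((4, 8))
print(vae_loss(x, x, mu, logvar))  # 0.0
```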

Temporal Model
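For the flow head, a common supervision signal against the precomputed RAFT flows is average endpoint error (EPE), the L2 distance between predicted and ground-truth flow vectors. Whether this project uses EPE, L1, or another flow loss is an assumption here.

```python
import numpy as np

def epe(pred, gt):
    # pred, gt: (H, W, 2) flow fields; L2 norm per pixel, then averaged.
    return np.mean(np.linalg.norm(pred - gt, axis=-1))

gt = np.zeros((4, 4, 2))
pred = np.full((4, 4, 2), 3.0)
# Every flow vector differs by (3, 3), so the EPE is sqrt(18) ~ 4.2426.
print(round(epe(pred, gt), 4))  # 4.2426
```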

Test-Time Training

Evaluation Metrics

Training

Phase 1: Train 2D VAE (~50-200 epochs)
Phase 2: Train temporal (~50-100 epochs)
Phase 3: Joint fine-tune (~20-50 epochs)

Precompute Optical Flows

./run_precompute.sh

Train 2D VAE

./run_train_vae.sh

Train Temporal Model

./run_train_temporal.sh

Joint Fine-Tune

./run_joint.sh

Inference

python inference.py \
    --vae_checkpoint checkpoints/vae_epoch0200.pt \
    --temporal_checkpoint checkpoints/temporal_epoch0005.pt \
    --video_folders video_test/video_15000-16000 \
    --img_h 128 --img_w 128 \
    --num_input_frames 20 \
    --num_pred_frames 12 \
    --ttt --ttt_steps 10 \
    --output_dir outputs \
    --device cuda

Notes

References