I decided to try out a transformer-based architecture after my initial explorations into world modelling with a U-Net-based diffusion approach. The new architecture combined a vision-transformer-based masked autoencoder (MAE) with a transformer-based predictor. The predictor interleaved encoded frames with encoded actions to predict the next frame in latent space, which the MAE decoder then reconstructed. You can see an initial version of this here. I spent a few months trying to get something working on the data from my JetBot, but didn’t see much progress. Despite receiving a sequence of frames and actions, the model ignored the action data; it simply used the last frame as its prediction for the next. My log up to Day 43 covers what I tried (more recently I found some evidence that the JetBot data quality might not be very good, but that’s a different story). I decided to simplify the problem and ended up with a dot on a screen that either moves to the right or stays put based on a binary action. The dot wraps around to the other side of the screen as it goes, so I called this the toroidal dot world.
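To make the setup concrete, here is a minimal sketch of how such a toroidal dot world can be generated; the frame size, dot radius, and step size are illustrative values picked for this example, not the ones used in the actual experiments.

```python
import numpy as np

# Illustrative constants; the real frame size, dot radius, and step size may differ.
FRAME_SIZE = 64
DOT_RADIUS = 2
STEP = 4

def render(x, y, size=FRAME_SIZE):
    """Render a white dot on a black frame at position (x, y)."""
    frame = np.zeros((size, size), dtype=np.float32)
    ys, xs = np.ogrid[:size, :size]
    frame[(xs - x) ** 2 + (ys - y) ** 2 <= DOT_RADIUS ** 2] = 1.0
    return frame

def step(x, action):
    """Binary action: 1 moves the dot right by STEP pixels, 0 leaves it in place.
    The x-coordinate wraps around the frame width, hence 'toroidal'."""
    return (x + action * STEP) % FRAME_SIZE

# Roll out a short episode at a fixed y-position (as in the first experiment below).
x, y = 0, FRAME_SIZE // 2
frames, actions = [], []
for _ in range(16):
    a = np.random.randint(0, 2)  # 0 = stay, 1 = move right
    frames.append(render(x, y))
    actions.append(a)
    x = step(x, a)
```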

I started by collecting and training on data of a dot at a fixed y-position moving across the screen. The model struggled, consistently predicting either a black frame or simply repeating the previous frame, depending on whether I prioritized the latent-feature loss or the pixel-level loss. The predictor wasn’t really working, but the autoencoder seemed to do a good job of reconstructing images, so I thought I’d try recasting the problem to use just an autoencoder. Rather than treating prediction as a sequence of discrete images and actions, I combined everything into a single image, a ‘canvas’. In this format, frames are concatenated side-by-side, separated by colored bars that encode the preceding action.
This reframed prediction as an inpainting task: masking the last frame in the canvas and asking the model to fill it in. This allowed me to simplify the model to just the MAE, adapted to process canvases (code here).
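As a rough sketch of the canvas idea, the layout and masking could look like the following; the bar width and the way the binary action maps to bar intensity are assumptions for illustration, not necessarily what my code does.

```python
import numpy as np

BAR_WIDTH = 4  # assumed width of the action-encoding bar, in pixels

def build_canvas(frames, actions):
    """Lay frames out left to right, inserting a vertical bar between consecutive
    frames whose intensity encodes the action that produced the next frame."""
    h = frames[0].shape[0]
    pieces = [frames[0]]
    for action, frame in zip(actions, frames[1:]):
        bar = np.full((h, BAR_WIDTH), float(action), dtype=np.float32)
        pieces.extend([bar, frame])
    return np.concatenate(pieces, axis=1)

def mask_last_frame(canvas, frame_width):
    """Zero out the final frame so the model has to inpaint it."""
    masked = canvas.copy()
    masked[:, -frame_width:] = 0.0
    return masked

# Toy example: five 64x64 frames joined by four action bars.
frames = [np.random.rand(64, 64).astype(np.float32) for _ in range(5)]
actions = np.random.randint(0, 2, size=4)
canvas = build_canvas(frames, actions)
model_input = mask_last_frame(canvas, frame_width=64)
target = canvas  # the reconstruction target is the full, unmasked canvas
```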
This approach didn’t work at first: the inpainted region at the end of the canvas came out mostly black. After experimenting with a loss function that combined standard mean squared error (MSE) with a focal-weighted MSE, I began to see some progress. The focal-weighted MSE applies exponential weighting to the errors, placing extra emphasis on regions where the prediction error is high. This hybrid loss started giving predictions of multiple dots scattered roughly where the true dot might be.
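Roughly, the hybrid loss looks like the sketch below; the exact weighting function, the `beta` exponent, and the 50/50 blend are guesses for illustration rather than the values I settled on.

```python
import torch
import torch.nn.functional as F

def focal_weighted_mse(pred, target, beta=5.0):
    """MSE where each pixel's squared error is scaled by an exponential of that
    error, so regions the model gets badly wrong (e.g. the missing dot) dominate.
    beta and the normalization are assumptions, not the actual hyperparameters."""
    err = (pred - target) ** 2
    weights = torch.exp(beta * err).detach()  # exponential emphasis on large errors
    weights = weights / weights.mean()        # keep the scale comparable to plain MSE
    return (weights * err).mean()

def hybrid_loss(pred, target, alpha=0.5):
    """Blend of standard MSE and the focal-weighted term; the split is a guess."""
    return alpha * F.mse_loss(pred, target) + (1 - alpha) * focal_weighted_mse(pred, target)
```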
After tuning hyperparameters and speeding up the training pipeline so I could iterate faster, I trained a model that makes accurate action-conditioned predictions. I then made the toroidal world slightly harder by randomizing the dot’s starting y-position; the model still performed well on a held-out validation set (Days 66-73).
I’m now applying the same idea to real robot data. I’m starting with the LeRobot SO-101 arm, focusing on learning an action-conditioned model of a single joint before scaling up to more complex motion.