Building a World Model from Scratch

TL;DR I trained a world model from scratch and it kinda works. github repo

Intro

During my PhD, I was obsessed with a single question: how do we build machines that think like humans? For several reasons, I believed program synthesis would be a core ingredient for that path. I spent a long time exploring LLMs and program induction.

More recently, the idea of a world model grabbed my attention. Neuroscience has tons of evidence that our brains make decisions by simulating the future. World models have started to crack problems in real-world robotics. And Genie 3 looked like it might even make video game production obsolete.

At first, I started a project that modeled world dynamics using program synthesis, the tool I was most comfortable with. But the further I went, the more I had to admit that this approach was unlikely to solve real-world robotics in any meaningful way.

I don’t mean it can’t solve any real-world problems. When a domain’s inputs and outputs already have strong structure, program synthesis can excel. A large fraction of knowledge work that a person can do sitting at a computer will likely be solvable via a combination of LLMs + program synthesis. But real-world robotics is simply too noisy to be handled by a system where programs are first-class citizens (although programs may generate auxiliary training data for neural models).

From Symbolic to Neural

When I realized this, I turned my attention to neural world models and came across Dreamer 4. Dreamer 4 is a world model that simulates the Minecraft environment and trains a policy inside its “imagination.” It’s the first model to successfully mine diamonds in Minecraft while being trained purely from offline human gameplay logs.

What does that imply? A fundamental reason robot learning is hard is that you can’t really do trial-and-error learning in the real world. Errors often mean the robot breaks something. If we can get meaningful performance from offline data alone, that suggests a similar path for robotics: we might build robots that work reliably in the real world using only teleoperation data or human egocentric data.

So I began implementing Dreamer 4’s world model (unfortunately, the paper’s code was not released), and if possible, trying to improve it. This post is both a record of that journey and an educational resource for beginners getting into world models.

I started from the codebase at github.com/p-doom/jasmine, so I used JAX and flax.nnx. I also referenced many parts of github.com/edwhu/dreamer4-jax.

What Is a World Model?

When you move a coffee cup on your table, you don’t need to physically test whether it will pass through the table. Your brain runs a simulation that predicts what will happen next. That internal predictive simulator is essentially a world model.

In AI, we train a neural network to do the same thing: given the current observation $o_t$ and an action $a_t$, predict the next observation $o_{t+1}$. In practice (especially for high-dimensional inputs like video), we often predict in a compact latent space and decode back to pixels.

Once you have such a model, an agent can plan inside it or be trained inside it, rolling out hypothetical action sequences and observations without interacting with the real environment. This is especially valuable in robotics, where trial-and-error in the real world is prohibitively expensive.
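
To make "rolling out hypothetical action sequences" concrete, here is a minimal JAX sketch. The names `rollout` and `step_fn` are placeholders I made up for illustration, and the toy dynamics stand in for a trained latent dynamics model; this is not the code from this project.

```python
import jax
import jax.numpy as jnp

def rollout(step_fn, z0, actions):
    """Roll a one-step model forward over a sequence of actions.

    step_fn(z, a) -> next latent state (hypothetical stand-in for a trained model).
    z0: initial latent state; actions: array with a leading time axis of length T.
    Returns the T predicted latent states stacked along a leading time axis.
    """
    def body(z, a):
        z_next = step_fn(z, a)
        return z_next, z_next  # carry the new state and also record it

    _, trajectory = jax.lax.scan(body, z0, actions)
    return trajectory

# Toy dynamics for illustration only: the state simply accumulates the actions.
toy_step = lambda z, a: z + a
imagined = rollout(toy_step, jnp.zeros(2), jnp.ones((3, 2)))
```

No environment is touched here: every state after `z0` comes from the model, which is what lets a planner or policy try action sequences cheaply.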

Architecture

Dreamer 4 consists of two components: a video tokenizer and a dynamics model.

The video tokenizer maps high-dimensional video into low-dimensional latent tokens. The dynamics model predicts the future in this latent space, substantially cutting down compute per timestep.

The video tokenizer is trained with a masked autoencoding objective, a fairly standard choice for tokenizers: it randomly masks a portion of the input tokens and reconstructs the masked parts from neighboring and past tokens. One unusual aspect of the Dreamer 4 tokenizer is how it extracts latents. Instead of directly mapping an image into tokens, it concatenates special tokens to the input, feeds everything through the encoder, and then reads out the outputs corresponding to those special tokens as the latent representation.

The bottleneck compression ratio is determined by (1) how many special tokens you add, and (2) the output dimension of the final projection layer. For example, if you use 32 special tokens, you get 32 latent tokens. If the final projection outputs dimension 64, then a single frame is represented as a 32×64 matrix. If the video length is 64 frames, the full latent tensor becomes 64×32×64.
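
The readout and the resulting shapes can be sketched as follows. This is a simplified illustration, not the Dreamer 4 tokenizer: `encode_frame`, the identity "encoder", the patch count `P`, and the model width `D_model` are all assumptions, and each frame is encoded independently here, whereas the real encoder also attends to past frames.

```python
import jax
import jax.numpy as jnp

def encode_frame(patches, latent_queries, encoder_fn, w_proj):
    """patches: (P, D_model); latent_queries: (N_latent, D_model); w_proj: (D_model, D_latent)."""
    n_latent = latent_queries.shape[0]
    tokens = jnp.concatenate([patches, latent_queries], axis=0)  # (P + N_latent, D_model)
    hidden = encoder_fn(tokens)                                   # same shape as tokens
    # Keep only the outputs at the special-token positions, then project down.
    return hidden[-n_latent:] @ w_proj                            # (N_latent, D_latent)

# N_latent, D_latent, T follow the example in the text; P and D_model are made-up sizes.
P, D_model, N_latent, D_latent, T = 196, 128, 32, 64, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
video_patches = jax.random.normal(k1, (T, P, D_model))
latent_queries = jax.random.normal(k2, (N_latent, D_model))
w_proj = jax.random.normal(k3, (D_model, D_latent))
identity_encoder = lambda x: x  # placeholder; a real encoder would be a transformer

# Apply the per-frame encoder to every frame of the video.
encode_video = jax.vmap(
    lambda p: encode_frame(p, latent_queries, identity_encoder, w_proj)
)
latents = encode_video(video_patches)
print(latents.shape)  # (64, 32, 64)
```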

This design gives fine control over compute in the dynamics model. If you increase the number of latent tokens, the dynamics model must process more tokens per time step. Bottleneck capacity can also be adjusted through the latent dimension, which barely affects dynamics compute, so the key benefit is that you can tune bottleneck capacity and dynamics compute intensity almost independently.

The dynamics model is trained using a technique called shortcut forcing. Shortcut forcing roughly means training a shortcut model in a diffusion forcing style.

A shortcut model is a kind of flow-based model, with the benefit of faster inference than conventional flow models. At a high level, it learns a coarse step whose target is the composition of two fine steps.
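
A minimal sketch of that "two fine steps" target, written in the velocity parameterization of the original shortcut model paper (Dreamer 4's exact parameterization may differ). `model_fn(x, t, d)` is a hypothetical network that returns the average velocity for a step of size `d` starting at time `t`; the toy model below exists only to make the sketch runnable.

```python
import jax
import jax.numpy as jnp

def shortcut_target(model_fn, x_t, t, d):
    """Target for one step of size 2d, built by chaining two steps of size d."""
    v_first = model_fn(x_t, t, d)
    x_mid = x_t + d * v_first               # take the first fine step
    v_second = model_fn(x_mid, t + d, d)    # second fine step from the midpoint
    # The coarse step should match the average of the two fine velocities.
    # stop_gradient keeps the target fixed while the coarse prediction is trained.
    return jax.lax.stop_gradient(0.5 * (v_first + v_second))

def self_consistency_loss(model_fn, x_t, t, d):
    target = shortcut_target(model_fn, x_t, t, d)
    pred = model_fn(x_t, t, 2 * d)
    return jnp.mean((pred - target) ** 2)

# Toy stand-in for a network: velocity points toward the origin, ignoring t and d.
toy_model = lambda x, t, d: -x
x_t = jnp.array([1.0, 2.0])
target = shortcut_target(toy_model, x_t, t=0.0, d=0.25)
loss = self_consistency_loss(toy_model, x_t, t=0.0, d=0.25)
```

Because the toy model ignores `d`, its coarse prediction does not match the composed target and the loss is nonzero; a trained shortcut model learns to condition on `d` so that this gap closes.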

Diffusion forcing injects noise into each frame at an independently sampled level during training and trains the model to denoise it. At inference time, you provide a clean context and add noise to the last frame; denoising then generates video autoregressively.
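
The per-frame noising can be sketched like this. The convention here is an assumption on my part (signal level `tau = 1` means clean, `tau = 0` means pure noise, with linear interpolation between them), and the function names are illustrative.

```python
import jax
import jax.numpy as jnp

def mix(latents, noise, tau):
    """latents, noise: (B, T, N, D); tau: (B, T) signal level per frame."""
    tau = tau[..., None, None]  # broadcast over tokens and channels
    return (1.0 - tau) * noise + tau * latents

def sample_noisy_frames(key, latents):
    """Training-time corruption: every frame gets its own random signal level."""
    k_tau, k_noise = jax.random.split(key)
    batch, frames = latents.shape[:2]
    tau = jax.random.uniform(k_tau, (batch, frames))
    noise = jax.random.normal(k_noise, latents.shape)
    return mix(latents, noise, tau), tau

latents = jnp.ones((2, 5, 3, 4))
noisy, tau = sample_noisy_frames(jax.random.PRNGKey(0), latents)

# Inference-time pattern: clean context frames, fully noisy final frame.
tau_infer = jnp.concatenate([jnp.ones((2, 4)), jnp.zeros((2, 1))], axis=1)
start = mix(latents, jax.random.normal(jax.random.PRNGKey(1), latents.shape), tau_infer)
```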

Many additional techniques are used for speed and performance: time-space axial transformers reduce attention cost, and LLM techniques like RMSNorm and RoPE are heavily used. There is also ramp loss, which down-weights loss on samples with low signal level. For full details, refer to the Dreamer 4 paper: arXiv:2509.24527.
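
To show why axial attention is cheaper, here is a stripped-down sketch: attention within each frame over the token axis, followed by causal attention across time for each token slot. Learned projections, multiple heads, normalization, and RoPE are all omitted, so this illustrates only the reshaping pattern, not the Dreamer 4 block.

```python
import jax
import jax.numpy as jnp

def attend(q, k, v, causal=False):
    """Plain softmax attention over the second-to-last axis. q, k, v: (..., L, D)."""
    scores = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(q.shape[-1])
    if causal:
        length = q.shape[-2]
        mask = jnp.tril(jnp.ones((length, length), dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)  # hide future positions
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)

def axial_block(x):
    """x: (B, T, N, D) = batch, frames, tokens per frame, channels."""
    x = x + attend(x, x, x)                      # space: tokens attend within a frame
    xt = jnp.swapaxes(x, 1, 2)                   # (B, N, T, D)
    xt = xt + attend(xt, xt, xt, causal=True)    # time: each slot attends to its past
    return jnp.swapaxes(xt, 1, 2)                # back to (B, T, N, D)

x = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 3, 8))
y = axial_block(x)
```

For T = 64 frames and N = 32 tokens per frame, full attention over all 2048 tokens scores 2048² ≈ 4.2M query-key pairs, while the axial version scores 64·32² + 32·64² ≈ 0.2M before causal masking.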

Training

I had access to 4× A100 GPUs, and even that was not always available due to queueing. I couldn’t exhaustively experiment with every design choice; I needed to train as efficiently as possible from the start.

In my view, Dreamer 4’s biggest advantage was training with a long context length. Compared to MineWorld or Oasis—using 0.8s / 1.6s contexts (16 / 32 frames)—Dreamer 4’s 9.6s context (192 frames) is enormous, and that helps a lot for long-horizon simulation.

Dreamer 4 could do this with 1024 TPUs, but I had to compromise. In Minecraft, I set the maximum context length to 64 frames. I also mixed short batches (16 frames) with long batches during training, as Dreamer 4 does.

For CALVIN, since I used a smaller model, I used a maximum length of 96 frames, and again mixed it with shorter 24-frame batches.

CoinRun training loss curves.

CALVIN training loss curves.

Inference

After training the model and entering the first validation loop, I found validation was taking an abnormally long time. Autoregressive generation across multiple frames is naturally slow, but validation was taking one-third of total training time. The issue turned out to be complicated.

First, I started without a KV cache implementation, so the clean context had to be recomputed from scratch on every forward pass. I initially thought this was the sole cause.

After implementing KV cache as an internal model state and passing sanity checks, I ran validation again—and was shocked to see validation time did not decrease at all. After spending hours looking for a non-existent bug, I realized the issue wasn’t a bug; it was how JAX works.

Because JAX compiles computation graphs, changing input tensor shapes triggers recompilation, which is very slow. As KV cache accumulated, the sequence length increased step by step, and recompilation was happening every time the length grew by one.
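
The effect is easy to reproduce in isolation. In this small sketch, the Python-level `append` runs only while JAX traces the function, so the length of `trace_log` counts how many times the function was traced for compilation. The function name and sizes are made up for the demonstration.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def attend_to_cache(cache):
    trace_log.append(cache.shape)  # Python side effect: runs only during tracing
    return cache.sum()

# Growing cache: every call sees a new shape, so every call is re-traced.
for length in range(1, 5):
    attend_to_cache(jnp.ones((length, 8)))
traces_growing = len(trace_log)                  # 4

# Fixed-size cache: the shape never changes, so only the first call is traced.
for _ in range(4):
    attend_to_cache(jnp.ones((16, 8)))
traces_fixed = len(trace_log) - traces_growing   # 1
```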

Looking into how this is typically handled in JAX/Flax, I found that Flax initializes the cache at a fixed size and updates elements in place.
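
A minimal single-head sketch of the fixed-size idea, not Flax's implementation: the cache is preallocated at `MAX_LEN`, the new key and value are written at a running index, and unused slots are masked out of the softmax. Shapes never change, so the jitted function is compiled once. All names and sizes here are illustrative.

```python
import jax
import jax.numpy as jnp

MAX_LEN, HEAD_DIM = 8, 4

@jax.jit
def write_and_attend(k_cache, v_cache, index, q, k_new, v_new):
    """k_cache, v_cache: (MAX_LEN, HEAD_DIM); index: scalar; q, k_new, v_new: (HEAD_DIM,)."""
    # Functional update at a dynamic index; under jit XLA can often do this in place.
    k_cache = k_cache.at[index].set(k_new)
    v_cache = v_cache.at[index].set(v_new)
    scores = k_cache @ q / jnp.sqrt(HEAD_DIM)    # (MAX_LEN,)
    valid = jnp.arange(MAX_LEN) <= index         # slots written so far
    scores = jnp.where(valid, scores, -jnp.inf)  # ignore empty slots
    out = jax.nn.softmax(scores) @ v_cache
    return k_cache, v_cache, out

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (HEAD_DIM,))
ks = jax.random.normal(kk, (3, HEAD_DIM))
vs = jax.random.normal(kv, (3, HEAD_DIM))

k_cache = jnp.zeros((MAX_LEN, HEAD_DIM))
v_cache = jnp.zeros((MAX_LEN, HEAD_DIM))
for i in range(3):
    k_cache, v_cache, out = write_and_attend(k_cache, v_cache, i, q, ks[i], vs[i])

# Reference: ordinary attention over exactly the three keys written so far.
reference = jax.nn.softmax(ks @ q / jnp.sqrt(HEAD_DIM)) @ vs
```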

But the MultiHeadAttention class that included this feature wrapped scaled dot product attention (SDPA), which made it incompatible with RoPE, where you need to directly modify queries and keys. In the end, I implemented a new attention class that supports both RoPE and KV cache.
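
For reference, here is a minimal sketch of why RoPE needs direct access to queries and keys: each one is rotated by an angle that depends on its position before the dot product is taken. This uses the common "rotate-half" layout and a base of 10000, which are assumptions rather than details taken from this project.

```python
import jax
import jax.numpy as jnp

def apply_rope(x, positions, base=10000.0):
    """x: (L, D) with even D; positions: (L,) integer positions."""
    half = x.shape[-1] // 2
    freqs = base ** (-jnp.arange(half) / half)     # (half,) rotation frequencies
    angles = positions[:, None] * freqs[None, :]   # (L, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

q = jax.random.normal(jax.random.PRNGKey(0), (1, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 8))

def score(q_pos, k_pos):
    """Attention logit between a query at q_pos and a key at k_pos."""
    return jnp.sum(apply_rope(q, jnp.array([q_pos])) * apply_rope(k, jnp.array([k_pos])))

near = score(3, 1)    # relative offset of 2
far = score(10, 8)    # same offset, different absolute positions
```

The two scores agree because the logit depends only on the relative offset. With a KV cache, this means keys can be rotated once at their own position before being written, and each new query is rotated at the current cache index.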

This new class also changed cache-index handling to better match diffusion/flow inference: instead of automatically advancing the cache index on every forward pass, it advances only when denoising finishes (manually controlled).
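
The index handling can be illustrated with a toy loop. This is a heavily simplified sketch: the cache stores whole frames rather than per-layer keys and values, and `generate_frame` and `denoise_step_fn` are hypothetical names. The point is only that the same slot is overwritten during denoising and the index moves once per frame.

```python
import jax.numpy as jnp

def generate_frame(cache, index, denoise_step_fn, x_noisy, num_steps):
    """Denoise one frame. Slot `index` is overwritten at every denoising step,
    and the index advances only once, after the final write."""
    x = x_noisy
    for _ in range(num_steps):
        x = denoise_step_fn(x, cache, index)
        cache = cache.at[index].set(x)  # overwrite the same slot, do not advance
    return cache, index + 1, x

# Toy "denoiser" for illustration only: halves the frame at each step.
toy_denoise = lambda x, cache, index: x / 2.0

cache = jnp.zeros((4, 2))
cache, index, frame0 = generate_frame(cache, 0, toy_denoise, jnp.array([8.0, 8.0]), num_steps=3)
cache, index, frame1 = generate_frame(cache, index, toy_denoise, jnp.array([4.0, 4.0]), num_steps=2)
```

If the index advanced on every forward pass instead, a frame denoised in three steps would occupy three cache slots, two of them holding partially denoised content.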

With this new class, inference became dramatically faster, and validation became negligible in total training time. Still, a single training run takes 3–5 days in Minecraft. The waiting is the hardest part.

Dataset

Because jasmine supports fast prototyping on Procgen’s CoinRun, it was the first dataset I tried. Both the tokenizer and the dynamics model learned reasonably well.

A world model trained on randomly generated CoinRun trajectories. Left: real environment frames. Right: imagination by the model given only the first 8 frames (0.8s).

So I moved to Minecraft. For training data, I used the OpenAI VPT contractor demonstrations, which was the largest dataset I had ever worked with. The video and action files were about 7TB, too large to fit on my server, so I had to resize the original 360p videos down to 224p.

After starting tokenizer training, I noticed GPU utilization was abnormally low. After investigating, I found the issue: when sampling a batch, I was reading entire ~5-minute videos each time.

For example, if the batch size is 32 and seq_len is 64, and videos are stored in 5-minute chunks, preparing a single batch requires reading 32 × 5 minutes = 160 minutes of video. At that point, hard disk I/O becomes slower than the GPU step time. So the dataset needed to be preprocessed into 3.2-second chunks (64 frames). In the end, I prepared 3 million 3.2-second video clips for training.
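
The arithmetic behind this is worth spelling out. The frame rate of 20 fps is implied by 64 frames spanning 3.2 seconds; the sketch assumes each sampled video is decoded in full and is exactly 5 minutes long, which is only approximately true.

```python
batch_size = 32
seq_len = 64          # frames actually used per sample
fps = 20              # implied by 64 frames = 3.2 seconds
video_minutes = 5     # approximate length of one stored video

frames_needed = batch_size * seq_len
frames_read = batch_size * video_minutes * 60 * fps
overhead = frames_read / frames_needed
print(frames_needed, frames_read, overhead)  # 2048 192000 93.75

# Total footage in the preprocessed set: 3 million clips of 3.2 seconds each.
total_hours = 3_000_000 * 3.2 / 3600  # roughly 2667 hours
```

So before chunking, the loader was reading almost two orders of magnitude more frames than each batch actually used.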

This video chunking process alone took forever, let alone training the tokenizer and dynamics model. So I ran experiments on CALVIN first, which involves controlling a robot arm. CALVIN uses 200×200 images as observations and has a 7D continuous action space.

Results

On CALVIN, you can see that inside the world model’s imagination, the robot arm behaves quite similarly to the ground truth (GT). The causal mechanisms of the desk (turning the light on with the slider) are pretty well modeled too.

A world model trained on human demonstrations. Left: real environment frames. Right: imagination by the model given only the initial frame.

However, object consistency is not well maintained (see the purple block in the second video changing its size). My guess is that the latent compression ratio is too aggressive, making it difficult for the flow model to generate dynamics precisely. Even if tokenizer reconstruction is perfect, whether the latent space is easy to learn is a different story.

With the current compression ratio, even a small movement in latent space can map to a different representation, so small accumulated generation errors can turn into large shifts over time.

And finally, the Minecraft dynamics model is learning something! It still has a few days of training left to fully converge, but you can already see it has some grasp of the Minecraft environment (the crafting table interface, placing blocks, etc.).

A world model trained on Minecraft. Left: real environment frames. Right: imagination by the model given only the first 8 frames (0.4s). After a few steps, the frames become very blurry.

Next Steps

First, to address the object consistency problem, I plan to retrain with a larger tokenizer bottleneck, mainly by increasing the number of latent tokens. Of course, that means even longer training time.

The MeanFlow paper reports better performance than shortcut models with far less training compute. I’m considering switching the flow model loss to a MeanFlow objective to reduce the painfully long training time. Some people in the community say MeanFlow can be unstable in practice and may not perform well in difficult domains, but it’s worth trying.

Another promising direction is improving long-term memory. But rather than improving the world model’s long-term consistency alone, the more important question is: when paired with a policy, how can memory be used effectively to solve tasks?

References

  1. Hafner, D., Yan, W., & Lillicrap, T. (2025). Training Agents Inside of Scalable World Models. arXiv:2509.24527.
  2. Google DeepMind. (2025). Genie 3. Website.
  3. Guo, J. et al. (2025). MineWorld: a Real-Time and Open-Source Interactive World Model on Minecraft. arXiv:2504.08388.
  4. Decart & Etched. (2024). Oasis: A Universe in a Transformer. oasis-model.github.io.
  5. Frans, K., Hafner, D., Levine, S., & Abbeel, P. (2025). One Step Diffusion via Shortcut Models. ICLR 2025. arXiv:2410.12557.
  6. Chen, B. et al. (2024). Diffusion Forcing: Next-Token Prediction Meets Full-Sequence Diffusion. NeurIPS 2024. arXiv:2407.01392.
  7. Geng, Z. et al. (2025). Mean Flows for One-step Generative Modeling. NeurIPS 2025. arXiv:2505.13447.
  8. Baker, B. et al. (2022). Video PreTraining (VPT): Learning to Act by Watching Unlabeled Online Videos. arXiv:2206.11795.
  9. Mees, O., Hermann, L., Rosete-Beas, E., & Burgard, W. (2022). CALVIN: A Benchmark for Language-Conditioned Policy Learning for Long-Horizon Robot Manipulation Tasks. IEEE RA-L. arXiv:2112.03227.

© Kang-il Lee