Overview
CrystalFormer is a unified autoregressive transformer model for inorganic crystalline material generation that supports both de novo crystal generation (DNG) and crystal structure prediction (CSP) within a single probabilistic framework. It is specifically designed for space-group-controlled generation of crystalline materials. Space group symmetry significantly simplifies the crystal space, which is crucial for data- and compute-efficient generative modeling of crystalline materials.
The model can:
- De novo generation $p(C|\varnothing)$: Generate plausible crystal structures from scratch, without any formula constraint.
- Crystal structure prediction $p(C|f)$: Generate crystal structures conditioned on a given chemical formula $f$.
We provide a pretrained checkpoint to support both functionalities. No architectural change is required — CrystalFormer seamlessly switches behavior depending on whether a formula is supplied. Furthermore, the performance of both functionalities can be boosted via reinforcement fine-tuning.
Contents
- Contents
- Model Card
- Status
- Get Started
- Installation
- Available Weights
- Crystal Structure Prediction
- De Novo Generation
- Advanced Usage
- How to cite
Model Card
CrystalFormer is an autoregressive transformer for the probability distribution of crystal structures:
- De novo generation: $P(C|\varnothing) = P(g) P(W_1|...) P(A_1|...) P(X_1|...) ... P(L|...)$
- Formula-conditioned prediction: $P(C|f) = P(g|f) P(W_1|...) P(A_1|...) P(X_1|...) ... P(L|...)$
where the crystal structure $C$ is represented by the sequence $g-(W_{i}-A_{i}-X_{i})_{n}-L$:
- $f$: chemical formula, e.g. Cu<sub>12</sub>Sb<sub>4</sub>S<sub>13</sub>
- $g$: space group number 1-230
- $W$: Wyckoff letter ('a', 'b', ..., 'A')
- $A$: atom type ('H', 'He', ..., 'Og') in the chemical formula
- $X$: fractional coordinates
- $L$: lattice vector [a, b, c, alpha, beta, gamma]
- $P(W_i| ...)$ and $P(A_i| ...)$ are categorical distributions.
- $P(X_i| ...)$ is the mixture of von Mises distribution.
- $P(L| ...)$ is the mixture of Gaussian distribution.
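To make the coordinate distribution concrete, the sketch below evaluates a mixture of von Mises densities over a fractional coordinate in $[0, 1)$. The mixture parameters here are illustrative, not the model's actual prediction heads; the point is that mapping $x \mapsto 2\pi x$ makes the density periodic, as fractional coordinates require.

```python
import numpy as np
from scipy.stats import vonmises

# Hypothetical mixture parameters for one fractional coordinate (K = 2 here;
# the model predicts such parameters autoregressively).
weights = np.array([0.7, 0.3])             # mixture weights, sum to 1
locs = 2 * np.pi * np.array([0.25, 0.75])  # component means mapped to angles
kappas = np.array([50.0, 20.0])            # concentrations (larger = sharper)

def mixture_pdf(x):
    """Density of a fractional coordinate x in [0, 1) under the mixture.

    Mapping x to the angle 2*pi*x makes the density periodic, so
    p(x=0) == p(x=1), matching the periodicity of fractional coordinates.
    """
    theta = 2 * np.pi * np.asarray(x)
    comps = [w * vonmises.pdf(theta, kappa=k, loc=m)
             for w, m, k in zip(weights, locs, kappas)]
    # the 2*pi Jacobian converts the angular density to a density on [0, 1)
    return 2 * np.pi * np.sum(comps, axis=0)
```

Note that a Gaussian mixture would not satisfy this periodicity, which is why the von Mises distribution is used for $X$ but not for $L$.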
Only symmetry-inequivalent atoms enter the crystal representation; the remaining atoms are restored from the space group and Wyckoff letters. Wyckoff letters have a natural alphabetical ordering, starting with 'a' for the position whose site-symmetry group has maximal order and ending with the highest letter for the general position. The sampling procedure starts from higher-symmetry sites (with smaller multiplicities) and proceeds to lower-symmetry ones (with larger multiplicities). Only when the Wyckoff letter cannot fully determine the structure do the fractional coordinates need to be considered further in the loss or sampling.
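As a concrete illustration of the sequence $g-(W_{i}-A_{i}-X_{i})_{n}-L$, rock-salt NaCl can be written down by hand. The dict below is purely illustrative (the model works on tokenized sequences internally), but the crystallographic content is real: NaCl crystallizes in Fm-3m (No. 225) with Na on Wyckoff 4a and Cl on 4b.

```python
# Rock-salt NaCl: only the two symmetry-inequivalent sites enter the sequence;
# the remaining atoms of the conventional cell are recovered from the space
# group operations.
sequence = {
    "g": 225,  # space group number (Fm-3m)
    "sites": [  # (W, A, X) triplets, ordered 'a' before 'b'
        ("a", "Na", (0.0, 0.0, 0.0)),
        ("b", "Cl", (0.5, 0.5, 0.5)),
    ],
    "L": (5.64, 5.64, 5.64, 90.0, 90.0, 90.0),  # a, b, c, alpha, beta, gamma
}
```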
Status
Major milestones are summarized below.
- v0.6: CrystalFormer for unified de novo generation and crystal structure prediction.
- v0.5: Initial release of CrystalFormer-CSP for crystal structure prediction.
- v0.4.2: Add implementation of direct preference optimization.
- v0.4.1: Replace the absolute positional embedding with Rotary Positional Embedding (RoPE).
- v0.4: Add reinforcement learning (proximal policy optimization).
- v0.3: Add conditional generation in a plug-and-play manner.
- v0.2: Add Markov chain Monte Carlo (MCMC) sampling for template-based structure generation.
- v0.1: Initial implementation of crystalline material generation conditioned on the space group.
Get Started
Notebooks: The quickest way to get started with CrystalFormer is through our notebooks on Google Colab:
- ColabCSP: Running CrystalFormer-CSP seamlessly on Google Colab
- CrystalFormer-RL: Reinforcement fine-tuning for materials design
- CrystalFormer Quickstart: GUI notebook demonstrating conditional generation of crystalline materials with CrystalFormer
- CrystalFormer Application: Generating stable crystals with a given structure prototype. This workflow can be applied to tasks dominated by element substitution
Installation
Create a new environment and install the required packages. We recommend using Python 3.10.* and conda to create the environment:

```shell
conda create -n crystalgpt python=3.10
conda activate crystalgpt
```
Before installing the required packages, you need to install jax and jaxlib first.
CPU installation
```shell
pip install -U "jax[cpu]"
```
CUDA (GPU) installation
If you intend to use CUDA (GPU) to speed up training, it is important to install the appropriate versions of jax and jaxlib. We recommend checking the jax docs for the installation guide. The basic installation command is given below:
```shell
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install -U "jax[cuda12]"
```
Install required packages and command line tools
After installing jax and jaxlib, you need to install the crystalformer package:
```shell
pip install .
```
During installation, the command line tools in the cli directory will be automatically installed.
Available Weights
We release the weights of the model trained on the Alex20s dataset. More details are available in the model card.
Crystal Structure Prediction
<div align="center"> <img align="middle" src="imgs/csp.png" width="500" alt="logo"/> <h2>Thinking fast and slow for crystal structure prediction</h2> </div>

Sample
```shell
python ./main.py --optimizer none --restore_path RESTORE_PATH --K 40 --num_samples 1000 --formula Cu12Sb4S13 --save_path SAVE_PATH
```
- `optimizer`: the optimizer to use; `none` means no training, only sampling
- `restore_path`: the path to the model weights
- `K`: the top-K space groups to sample from uniformly
- `num_samples`: the number of samples to generate
- `formula`: the chemical formula
- `save_path`: [Optional] the path to save the generated structures; if not provided, the structures will be saved in the `RESTORE_PATH` folder
Instead of providing `K` for top-K sampling, you may directly provide your favorite space group number:

- `spacegroup`: the space group number (1-230)
The sampled structures will be saved in the `SAVE_PATH/output_Cu12Sb4S13.csv` file. To transform the generated structures from `g, W, A, X, L` to the `cif` format, you can use the following command:
```shell
python ./scripts/awl2struct.py --output_path SAVE_PATH --formula FORMULA
```
- `output_path`: the path to read the generated `L, W, A, X` and save the `cif` files
- `formula`: the chemical formula constraint of the structures
This will save the generated structures in the `cif` format to an `output_Cu12Sb4S13_struct.csv` file.
Relax generated structures with MLFF:
```shell
python scripts/mlff_relax.py \
    --restore_path SAVE_PATH \
    --filename output_Cu12Sb4S13_struct.csv \
    --model orb-v3-conservative-inf-mpa \
    --model_path path/to/orb-v3.ckpt \
    --relaxation
```
This will produce relaxed structures in relaxed_structures with predicted energies.
Energy Above Hull (Ehull)
Compute Ehull for all relaxed structures:
```shell
python scripts/e_above_hull_alex.py \
    --convex_path convex_hull_pbe.json.bz2 \
    --restore_path SAVE_PATH \
    --filename relaxed_structures.csv
```
End-to-End Pipeline
Run sampling → CIF conversion → relaxation → Ehull ranking:
```shell
./postprocess.sh \
    -r RESTORE_PATH \
    -k 40 \
    --relaxation true \
    -n 1000 \
    -f Cu12Sb4S13 \
    -s SAVE_PATH
```
In case you are curious about the parameters, run:
```shell
./postprocess.sh -h
```
Model Context Protocol (MCP) Server
CrystalFormer can be easily integrated with AI assistants via the Model Context Protocol (MCP). Please refer to the MCP README for detailed instructions on setting up and using the MCP server for crystal structure prediction.
De Novo Generation
Sample
```shell
python ./main.py --optimizer none --restore_path YOUR_MODEL_PATH --spacegroup 160 --num_samples 1000 --batchsize 1000 --temperature 1.0
```
- `optimizer`: the optimizer to use; `none` means no training, only sampling
- `restore_path`: the path to the model weights
- `spacegroup`: the space group number to sample
- `num_samples`: the number of samples to generate
- `batchsize`: the batch size for sampling
- `temperature`: the temperature for sampling
The sampling results will be saved in the `output_LABEL.csv` file, where `LABEL` is the space group number `g` specified by `--spacegroup`.
Evaluation
Before evaluating the generated structures, you need to transform the generated `g, W, A, X, L` into the `cif` format. You can use the following command to convert the generated structures and save them as a csv file:
```shell
python ./scripts/awl2struct.py --output_path YOUR_PATH --label SPACE_GROUP --num_io_process 40
```
- `output_path`: the path to read the generated `L, W, A, X` and save the `cif` files
- `label`: the label for saving the `cif` files, which is the space group number `g`
- `num_io_process`: the number of processes
> [!IMPORTANT]
> The following evaluation scripts require the `SMACT`, `matminer`, and `matbench-genmetrics` packages. We recommend installing them in a separate environment to avoid conflicts with other packages.
Calculate the structure and composition validity of the generated structures:
```shell
python ./scripts/compute_metrics.py --root_path YOUR_PATH --filename YOUR_FILE --num_io_process 40
```
- `root_path`: the path to the dataset
- `filename`: the filename of the generated structures
- `num_io_process`: the number of processes
Calculate the novelty and uniqueness of the generated structures:
```shell
python ./scripts/compute_metrics_matbench.py --train_path TRAIN_PATH --test_path TEST_PATH --gen_path GEN_PATH --output_path OUTPUT_PATH --label SPACE_GROUP --num_io_process 40
```
- `train_path`: the path to the training dataset
- `test_path`: the path to the test dataset
- `gen_path`: the path to the generated dataset
- `output_path`: the path to save the metrics results
- `label`: the label for saving the metrics results, which is the space group number `g`
- `num_io_process`: the number of processes
Note that the training, test, and generated datasets should contain structures within the same space group `g`, as specified by `--label`.
More details about post-processing are available in the scripts folder.
Advanced Usage
Reinforcement Fine-tuning
> [!IMPORTANT]
> Before running the reinforcement fine-tuning, please make sure you have installed the corresponding machine learning force field model or property prediction model. The `mlff_model` and `mlff_path` arguments in the command line should be set according to the model you are using. Currently, we only support the `orb` model for the $E_{hull}$ reward. `BatchRelaxer` is also needed for batch structure relaxation during fine-tuning.
```shell
train_ppo --folder ./data/ \
    --restore_path YOUR_PATH \
    --reward ehull \
    --convex_path YOUR_PATH/convex_hull_pbe.json.bz2 \
    --mlff_model orb-v3-conservative-inf-mpa \
    --mlff_path YOUR_PATH/orb-v3-conservative-inf-mpa-20250404.ckpt \
    --lr 1e-05 \
    --dropout_rate 0.0 \
    --K 40 \
    --batchsize 500 \
    --formula LiPH2O4
```
where
- `folder`: the folder to save the model and logs
- `restore_path`: the path to the pre-trained model weights
- `reward`: the reward function to use; `ehull` means the energy above the convex hull
- `convex_path`: the path to the convex hull data, used to calculate the $E_{hull}$; only used when the reward is `ehull`
- `mlff_model`: the machine learning force field model to predict the total energy; we support the `orb` model for the $E_{hull}$ reward
- `mlff_path`: the path to load the checkpoint of the machine learning force field model
Currently, CSP reinforcement fine-tuning only supports the `ehull` reward. For DNG reinforcement fine-tuning, simply omit the `--formula` argument.
Writing custom reward functions
Custom reward functions are implemented as Python factory functions that return a pair (reward_fn, batch_reward_fn). Follow the patterns in crystalformer/reinforce/reward.py to implement your own reward functions.
> [!CAUTION]
> Reward direction: The reinforcement fine-tuning uses gradient ascent combined with reward inversion, which effectively minimizes the reward. If you want to minimize a property (e.g., energy, formation energy), return the positive value directly. If you want to maximize a property, return its negative value as the reward. For example, to minimize $E_{hull}$, the reward function simply returns `ehull`.
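The sign convention can be captured once in small wrappers. In the sketch below, `predictor` stands in for any real property-prediction callable (hypothetical here):

```python
# Reward-direction wrappers for the convention described above: rewards are
# effectively minimized, so properties to maximize must be negated.

def make_minimize_reward(predictor):
    """For properties to minimize (e.g. E_hull): return the value as-is."""
    def reward_fn(x):
        return float(predictor(x))
    return reward_fn

def make_maximize_reward(predictor):
    """For properties to maximize (e.g. band gap): return the negative."""
    def reward_fn(x):
        return -float(predictor(x))
    return reward_fn
```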
Guidelines
- Signature: `reward_fn(x)` accepts a single sample tuple `(G, L, XYZ, A, W)` and returns a scalar reward (float or numpy scalar).
- Batch API: `batch_reward_fn(x)` accepts a batched `x = (G, L, XYZ, A, W)` (JAX arrays). It should convert inputs to CPU numpy arrays, compute per-sample rewards (e.g., by calling `reward_fn` or a vectorized routine), and return a `jax.numpy` array placed on the GPU (see the example below for device transfers using `jax.device_put`).
- Structure conversion: use `get_atoms_from_GLXYZAW(G, L, XYZ, A, W)` from `crystalformer.reinforce.reward` to convert the representation to ASE `Atoms` or a pymatgen `Structure` before calling property predictors or MLFFs.
- Robustness: catch exceptions and return a sensible dummy or clipped reward for failed predictions to avoid crashing training.
- Performance: for heavy operations (relaxations, MLFF evaluations), prefer parallel or batched implementations where possible.
Minimal example
```python
from crystalformer.reinforce import reward as reward_mod
from pymatgen.io.ase import AseAtomsAdaptor
import jax
import jax.numpy as jnp

def make_custom_reward_fn(model, dummy=0.0):
    ase_adaptor = AseAtomsAdaptor()

    def reward_fn(x):
        G, L, XYZ, A, W = x
        try:
            atoms = reward_mod.get_atoms_from_GLXYZAW(G, L, XYZ, A, W)
            struct = ase_adaptor.get_structure(atoms)
            val = model(struct)  # compute the property from the structure
            return float(val)
        except Exception:
            return float(dummy)  # dummy reward for failed predictions

    def batch_reward_fn(x):
        # move data to CPU numpy for Python-side processing
        x = jax.tree_util.tree_map(
            lambda _x: jax.device_put(_x, jax.devices('cpu')[0]), x)
        # iterate over samples (or use a parallel map) and collect rewards
        output = [reward_fn(sample) for sample in zip(*x)]
        # return a jnp.array placed back on the GPU
        return jax.device_put(jnp.array(output),
                              jax.devices('gpu')[0]).block_until_ready()

    return reward_fn, batch_reward_fn
```
See the concrete implementations in crystalformer/reinforce/reward.py for more complete patterns (device transfers, relaxation, parallelization and clipping of rewards).
Pretrain
```shell
python ./main.py --folder ./data/ --cfg_drop_prob 0.5 --train_path YOUR_PATH/alex20s/train.csv --valid_path YOUR_PATH/alex20s/val.csv
```
where
- `folder`: the folder to save the model and logs
- `cfg_drop_prob`: classifier-free guidance drop probability for formula conditioning. A value of `1` disables formula conditioning (DNG), while a value of `0` always enables formula conditioning (CSP)
- `train_path`: the path to the training dataset
- `valid_path`: the path to the validation dataset
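The `cfg_drop_prob` mechanism is standard classifier-free guidance dropout: with some probability, a sample's formula condition is replaced by a null token, so a single model learns both $p(C|f)$ and $p(C|\varnothing)$. A numpy sketch (token values and shapes are illustrative; the actual conditioning is handled inside CrystalFormer):

```python
import numpy as np

# Illustrative null token standing in for the "no formula" condition.
NULL_TOKEN = 0

def drop_condition(formula_tokens, cfg_drop_prob, rng):
    """Randomly mask the formula condition for a batch of shape (batch, seq)."""
    drop = rng.random(formula_tokens.shape[0]) < cfg_drop_prob  # per-sample
    out = formula_tokens.copy()
    out[drop] = NULL_TOKEN  # dropped samples see the unconditional null token
    return out
```

With `cfg_drop_prob=0.5` as in the command above, roughly half of each batch trains the conditional model and half the unconditional one.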
Test the prediction accuracy of space groups on the test dataset:
```shell
python scripts/predict_g.py --restore_path YOUR_PATH --valid_path YOUR_PATH --Nf 5 --Kx 16 --Kl 4 --h0_size 256 --transformer_layers 16 --num_heads 8 --key_size 32 --model_size 256 --embed_size 256 --batchsize 1000
```
How to cite
```bibtex
@article{cao2024space,
    title={Space Group Informed Transformer for Crystalline Materials Generation},
    author={Zhendong Cao and Xiaoshan Luo and Jian Lv and Lei Wang},
    year={2024},
    eprint={2403.15734},
    archivePrefix={arXiv},
    primaryClass={cond-mat.mtrl-sci}
}

@article{cao2025crystalformerrl,
    title={CrystalFormer-RL: Reinforcement Fine-Tuning for Materials Design},
    author={Zhendong Cao and Lei Wang},
    year={2025},
    eprint={2504.02367},
    archivePrefix={arXiv},
    primaryClass={cond-mat.mtrl-sci},
    url={https://arxiv.org/abs/2504.02367},
}

@misc{cao2025crystalformercsp,
    title={CrystalFormer-CSP: Thinking Fast and Slow for Crystal Structure Prediction},
    author={Zhendong Cao and Shigang Ou and Lei Wang},
    year={2025},
    eprint={2512.18251},
    archivePrefix={arXiv},
    primaryClass={cond-mat.mtrl-sci},
    url={https://arxiv.org/abs/2512.18251},
}
```
Note: This project is unrelated to https://github.com/omron-sinicx/crystalformer with the same name.