Exploring PlantCaduceus - Google Colab Updated (August 2025)

Mar 23, 2025

PlantCAD provides a streamlined approach to modeling and analyzing plant DNA sequences in a Google Colab environment. The PlantCaduceus language model offers powerful capabilities for predicting, visualizing, and interpreting biological sequences, showcasing how complex biological tasks become manageable with just a few lines of code.

URL (Updated - August 2025): https://colab.research.google.com/drive/1sBhpR2_Hs0SbsqFXB6IWJ1kErj4Szlpf?usp=sharing

PlantCAD GitHub Repository: https://github.com/kuleshov-group/PlantCaduceus

Python Code:


# -*- coding: utf-8 -*-
"""PlantCAD_Colab_Example.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1QW9Lgwra0vHQAOICE2hsIVcp6DKClyhO

## Basic examples with PlantCaduceus

### Setup environment
"""

# Check the CUDA toolkit version available in the Colab runtime
!nvcc --version

"""## Install specific pytorch
We are installing a specific version of PyTorch (2.3.1) to resolve some errors when loading mamba-ssm, The most critical part is `--index-url https://download.pytorch.org/whl/cu121`. This command tells pip to NOT use the default repository. Instead, it downloads a special version of PyTorch that was pre-compiled specifically for systems with a CUDA 12.1 driver.

#### Why is this necessary?
Google Colab has its own system-level NVIDIA driver (e.g., CUDA 12.5). By installing a PyTorch build that is aware of a compatible CUDA version (12.1 works perfectly with a 12.5 driver), we ensure that all parts of the library, including low-level components like Triton, can find the correct drivers and compile code successfully during model inference. This fixes both the `libcuda.so` and the `ModuleNotFoundError` errors.
"""

!pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121

!pip3 install mamba-ssm==2.2.2 transformers==4.40.0 git+https://github.com/dridk/PyVCF3.git@master scipy==1.12.0 biopython xgboost==2.0.3 scikit-learn==1.4.0 matplotlib

"""🔄 You may need to restart the session after running the above setup cells.

Once restarted, you’re all set to start exploring the PlantCAD model!

## Testing that mamba-ssm installed successfully
"""

# Test core dependencies
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd

device = 'cuda:0'

# Test PlantCAD model loading
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l32')
model = AutoModelForMaskedLM.from_pretrained('kuleshov-group/PlantCaduceus_l32', trust_remote_code=True)
model.to(device)
print("✅ Installation successful!")

"""## Play around with the model inputs and outputs"""

# Example plant DNA sequence (512bp max)
sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
# Get embeddings
encoding = tokenizer.encode_plus(
    sequence,
    return_tensors="pt",
    return_attention_mask=False,
    return_token_type_ids=False
)

input_ids = encoding["input_ids"].to(device)
with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)

embeddings = outputs.hidden_states[-1]
print(f"Embedding shape: {embeddings.shape}")  # [batch_size, seq_len, embedding_dim]

"""## Averaging forward and reverse embeddings

PlantCaduceus is bi-directional and reverse-complement equivariant: the first half of the embedding channels corresponds to the forward strand, and the second half to the reverse-complemented strand. We therefore average the two halves before training a downstream classifier.
"""

# Move the embeddings to CPU as float32 for NumPy processing
embeddings = embeddings.to(torch.float32).cpu().numpy()

# Split the channels: first half = forward strand, second half = reverse complement
hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., 0:hidden_size]
reverse = embeddings[..., hidden_size:]
reverse = reverse[..., ::-1]  # re-align the reverse-complement channels with the forward half
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)

averaged_embeddings
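
"""### Hypothetical downstream classifier
A minimal sketch, not from the original notebook: the averaged per-base embeddings can be fed to a downstream classifier such as XGBoost (installed above). The labels here are random placeholders purely for illustration, not real annotations.
"""

import numpy as np
from xgboost import XGBClassifier

X = averaged_embeddings[0]                    # (seq_len, hidden_size): one feature row per base
y = np.random.randint(0, 2, size=X.shape[0])  # placeholder labels, illustration only

clf = XGBClassifier(n_estimators=50, max_depth=4)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))               # per-class probabilities for the first 5 bases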

"""### Masked token prediction"""

pos = 255
sequence[pos]  # the true base at this position is A

input_ids[0, pos] = tokenizer.mask_token_id
with torch.inference_mode():
    outputs = model(input_ids=input_ids)

nucleotides = list('acgt')
logits = outputs.logits
logits = logits[:, pos, [tokenizer.get_vocab()[nc] for nc in nucleotides]]
probs = torch.nn.functional.softmax(logits.cpu(), dim=1).numpy()

probs

df = pd.DataFrame(dict(nucleotides = nucleotides, probs = probs[0]))

"""### The base A got the highest probability ✌🏻"""

df
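
"""### Wrapping the steps into a helper (sketch)
A small convenience function, not part of the original notebook, that bundles the masked-prediction steps above so any position can be queried in one call.
"""

def predict_base(seq, pos):
    """Return the model's probability for each nucleotide at a masked position."""
    enc = tokenizer.encode_plus(seq, return_tensors="pt",
                                return_attention_mask=False,
                                return_token_type_ids=False)
    ids = enc["input_ids"].to(device)
    ids[0, pos] = tokenizer.mask_token_id  # mask the position of interest
    with torch.inference_mode():
        out = model(input_ids=ids)
    nts = list('acgt')
    logits = out.logits[:, pos, [tokenizer.get_vocab()[nc] for nc in nts]]
    return dict(zip(nts, torch.nn.functional.softmax(logits.cpu(), dim=1).numpy()[0]))

predict_base(sequence, 255)  # should again rank A highest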

"""## Run some examples on zero-shot mutation effect prediction"""

# clone PlantCAD repo
!git clone https://github.com/kuleshov-group/PlantCaduceus.git

# Download and decompress the maize B73 (NAM 5.0) reference genome from MaizeGDB
!wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
!gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz

!python PlantCaduceus/src/zero_shot_score.py \
    -input-vcf PlantCaduceus/examples/example_maize_snp.vcf \
    -input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
    -output scored_variants.vcf \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:0'

"""🎉 Nice work! The variants have been scored, and the results are now embedded in the VCF’s INFO field."""

# Show the first 10 scored variants: chromosome, position, ref, alt, and the INFO field
!grep -v '#' scored_variants.vcf | awk -v OFS='\t' '{print $1,$2,$4,$5,$8}' | head -n 10
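
"""### Loading the scores into pandas (sketch)
An optional sketch, assuming a standard 8-column VCF body: load the scored variants into a DataFrame for downstream filtering or plotting. The exact INFO key that holds the score depends on the script's output, so the raw INFO string is kept as-is here.
"""

import pandas as pd

cols = ['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO']
vcf_df = pd.read_csv('scored_variants.vcf', sep='\t', comment='#',
                     header=None, usecols=range(8), names=cols)
print(vcf_df[['CHROM', 'POS', 'REF', 'ALT', 'INFO']].head())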