Exploring PlantCaduceus - Google Colab

Mar 23, 2025

PlantCAD provides a streamlined approach to modeling and analyzing plant DNA sequences in a Google Colab environment. The PlantCaduceus language model, pretrained on plant genomes, offers powerful capabilities for embedding, interpreting, and making predictions on biological sequences, showing how complex genomic tasks become manageable with just a few lines of code.

Colab Notebook: https://colab.research.google.com/drive/1iIJP-yxpyf_U8EtVPEjh2cX2low2GmFO?usp=sharing

PlantCAD GitHub Repository: https://github.com/kuleshov-group/PlantCaduceus

Python Code:

# -*- coding: utf-8 -*-
"""PlantCAD_Colab_Example.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1QW9Lgwra0vHQAOICE2hsIVcp6DKClyhO

## Basic examples with PlantCaduceus

### Setup environment
"""

# Quote the extras so the shell doesn't try to glob the brackets
!pip install "mamba-ssm[causal-conv1d]"

!nvidia-smi
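
"""#### Optional install check (not in the original notebook)

A quick sanity check before loading the model: importing `mamba_ssm` fails
fast if the install above didn't succeed (it expects a GPU runtime with a
compatible CUDA toolchain).
"""

import mamba_ssm  # raises ImportError if the pip install above failed
print("mamba-ssm imported OK")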

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import numpy as np
import pandas as pd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "kuleshov-group/PlantCaduceus_l32"

# PlantCaduceus ships custom modeling code on the Hub, so trust_remote_code is required
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).to(device)
model.eval()

print(f"✅ Model loaded on {device}")
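
"""#### Optional model sanity check (not in the original notebook)

A minimal sketch: report the loaded checkpoint's parameter count and dtype.
The exact numbers depend on the PlantCaduceus_l32 checkpoint.
"""

# Quick check on the loaded module, assuming a standard PyTorch model
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params / 1e6:.1f}M, dtype: {next(model.parameters()).dtype}")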

"""### Tokenize the sequence"""

sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
encoding = tokenizer.encode_plus(
    sequence,
    return_tensors="pt",
    return_attention_mask=False,
    return_token_type_ids=False,
)
input_ids = encoding["input_ids"].to(device)
print(input_ids.shape)
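
"""#### Optional tokenization check (not in the original notebook)

A small sketch to verify the character-level tokenization round-trips:
decode the ids back to text and compare lengths. This assumes the tokenizer
adds no special tokens by default, which the masking step below also relies
on (position `pos` indexes both `sequence` and `input_ids` identically).
"""

decoded = tokenizer.decode(input_ids[0])
print(len(sequence), input_ids.shape[1])  # should match if no special tokens are added
print(decoded[:60])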

"""### Embedding"""

with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)
# The last hidden state holds one embedding per input token:
# (batch, sequence_length, hidden_size)
embeddings = outputs.hidden_states[-1]

print(embeddings.shape)

"""#### Averaging forward and reverse embeddings"""

embeddings = embeddings.to(torch.float32).cpu().numpy()

# Caduceus concatenates forward-strand and reverse-complement channels along
# the hidden dimension: split them, flip the reverse half, and average
hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., 0:hidden_size]
reverse = embeddings[..., hidden_size:]
reverse = reverse[..., ::-1]
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)
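
"""#### Optional: sequence-level embedding (not in the original notebook)

A minimal sketch of a common downstream use, assuming you want one fixed-size
vector per sequence: mean-pool the strand-averaged token embeddings over the
length dimension, e.g. as features for a classifier.
"""

# Mean-pool over positions: (batch, length, hidden) -> (batch, hidden)
sequence_embedding = averaged_embeddings.mean(axis=1)
print(sequence_embedding.shape)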

"""### Masked token prediction"""

pos = 255
print(f"Reference base at position {pos}: {sequence[pos]}")

# Mask the target position and re-run the model
input_ids[0, pos] = tokenizer.mask_token_id
with torch.inference_mode():
    outputs = model(input_ids=input_ids)

# Pull out the logits for a/c/g/t at the masked position and softmax them
nucleotides = list('acgt')
vocab = tokenizer.get_vocab()
logits = outputs.logits[:, pos, [vocab[nc] for nc in nucleotides]]
probs = torch.nn.functional.softmax(logits.cpu(), dim=1).numpy()

print(probs)

df = pd.DataFrame(dict(nucleotides=nucleotides, probs=probs[0]))
print(df)
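
"""#### Optional: zero-shot variant scoring sketch (not in the original notebook)

Masked-token probabilities like the ones above are commonly turned into a
variant score. A minimal sketch with a hypothetical alternate allele: the
log-likelihood ratio of alt vs. ref, where negative values mean the model
finds the alternate base less likely in this context.
"""

import math

ref = sequence[pos].lower()        # reference base at the masked position
alt = 'a' if ref != 'a' else 'g'   # hypothetical alternate allele for illustration
p = dict(zip(nucleotides, probs[0]))
llr = math.log(p[alt]) - math.log(p[ref])
print(f"LLR (alt={alt} vs ref={ref}): {llr:.3f}")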