RB.
Cart 0

Patch Mpt -

# Test attention mask expansion
#     mask_2d = torch.tensor([[0, 0, 1, 1]])  # batch=1, key_len=4
#     expanded = patch_attention_mask(mask_2d, query_len=3, key_len=4, dtype=torch.float32)
#     print(f"Expanded mask shape: {expanded.shape}")  # (1,1,3,4)
#     print(expanded)
#
# | Issue           | Before patch                          | After patch                          |
# |-----------------|---------------------------------------|--------------------------------------|
# | Rotary cache    | Recomputes every call, wastes memory  | Only recomputes when seqlen changes  |
# | Mask expansion  | Only supports 2D masks                | Supports 2D/3D/4D, correct broadcast |
# | Cross-attention | Mask shape mismatch                   | Proper (batch,1,q_len,k_len)         |
#
# If you meant a firmware patch for an MPT controller (such as those used in
# automotive systems or industrial PLCs), I can write a .bin patching script
# in Python or C. Just clarify the target.

# If already 4D, assume correct if attention_mask.dim() == 4: return attention_mask.to(dtype)

# Monkey-patch attention mask expansion function if model has it
#     if hasattr(model, "_expand_attention_mask"):
#         model._expand_attention_mask = patch_attention_mask
#         print("[PATCH] Replaced _expand_attention_mask")
#
# ----------------------------------------------------------------------
# Usage example
# ----------------------------------------------------------------------
# if __name__ == "__main__":
#     # Assume you have an MPT model loaded
#     # from transformers import AutoModel
#     # model = AutoModel.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
#     # apply_mpt_patches(model)

def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
    """Rebuild the cached rotary cos/sin tables when they are stale.

    The tables are recomputed when no cache exists yet, when ``seq_len``
    differs from the cached length, or when the cached tensors have a
    different dtype or live on a different device than requested.
    Otherwise the call returns without doing any work.

    Args:
        seq_len: Number of positions the tables must cover.
        device: Device on which the tables are built and stored.
        dtype: Dtype the stored cos/sin tables are cast to.

    Side effects:
        Sets ``self._cached_cos``, ``self._cached_sin`` and
        ``self._cached_seq_len``. Each table has ``seq_len`` rows and
        twice as many columns as there are inverse frequencies (i.e.
        ``self.dim`` columns when ``self.dim`` is even).
    """
    device = torch.device(device)
    cached = self._cached_cos

    # A bare "cuda" request (index None) is satisfied by a cache on any
    # index of that device type, so it does not force a rebuild every call.
    if (
        cached is not None
        and seq_len == self._cached_seq_len
        and cached.dtype == dtype
        and cached.device.type == device.type
        and (device.index is None or cached.device.index == device.index)
    ):
        return

    # Build the inverse frequencies on the target device so the product
    # below never mixes CPU and accelerator tensors.
    exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
    inv_freq = 1.0 / (self.base ** exponents)

    t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)

    # Duplicate the frequencies so the table spans both rotated halves.
    emb = torch.cat((freqs, freqs), dim=-1)
    self._cached_cos = emb.cos().to(dtype)
    self._cached_sin = emb.sin().to(dtype)
    self._cached_seq_len = seq_len

# Case: (batch, 1, key_len) elif attention_mask.dim() == 3 and attention_mask.size(1) == 1: mask = attention_mask[:, :, None, :] else: raise ValueError(f"Unexpected mask shape: attention_mask.shape") # Test attention mask expansion mask_2d = torch

class PatchedRotaryEmbedding(nn.Module):
    """Rotary embedding with cache reset on seqlen change.

    The constructor only records the configuration and creates empty cache
    slots; no cos/sin tables are computed at construction time.

    Args:
        dim: Rotary dimension, stored as ``self.dim``.
        max_seq_len: Stored as ``self.max_seq_len``; not otherwise used by
            the constructor.
        base: Frequency base, stored as ``self.base``.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Cache slots start empty; None means nothing has been computed yet.
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = None

# patches/mpt_patch_rotary_cache.py
"""
Patch for MPT model:
- Fix rotary embedding cache when sequence length changes between forward passes.
- Correct attention mask broadcasting for cross-attention layers.
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn


# ----------------------------------------------------------------------
# 1. Patch Rotary Embedding Cache
# ----------------------------------------------------------------------
def patched_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in two halves and rotate them.

    With ``x1, x2`` being the first and second half of the last dimension,
    the result is ``cat(-x2, x1)`` along that same dimension.

    Args:
        x: Input tensor; presumably its last dimension has even size so the
            two halves match — TODO confirm with callers.

    Returns:
        A tensor built by concatenating ``-x2`` and ``x1`` on the last axis.

    NOTE(review): the original description claimed this is "fixed for fp16
    stability"; nothing in the body is dtype-specific, so that claim is
    unverified.
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Case: (batch, key_len) -> expand to (batch, 1, 1, key_len) if attention_mask.dim() == 2: mask = attention_mask[:, None, None, :]