ModernTrajectoryNet: Transaction Embedding Classifier

A state-of-the-art PyTorch embedding classifier trained with modern deep learning techniques for transaction categorization. The model learns to project transaction embeddings toward their target category embeddings through trajectory-based contrastive learning.

Model Architecture

ModernTrajectoryNet combines several modern architectural innovations (a reference sketch of each follows the component list):

Core Components

  1. RMSNorm (Root Mean Square Layer Normalization)

    • Computationally cheaper than LayerNorm (rescales only, with no mean-centering) while matching its stability
    • Used in LLaMA, Gopher, and T5
    • Provides consistent gradient flow through deep networks
  2. SwiGLU (Swish-Gated Linear Unit)

    • Strong activation choice for transformer feed-forward networks (used in PaLM and LLaMA)
    • Empirically outperforms GELU and ReLU in feed-forward layers
    • Gate mechanism: SiLU(W1 x) ⊙ (W2 x), i.e. a swish-activated projection gating a parallel linear projection
  3. SEBlock (Squeeze-and-Excitation)

    • Channel attention mechanism
    • Allows dynamic weighting of embedding dimensions
    • Context-aware feature recalibration
  4. ModernBlock (Pre-Norm Architecture)

    • RMSNorm → SwiGLU → SEBlock → Residual Connection
    • Incorporates layer scaling and stochastic depth (DropPath)
    • Enables training of very deep networks
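
Taken together, a minimal PyTorch sketch of these four components might look as follows. This is illustrative only: dimensions, initialization, and the exact DropPath behavior in the released weights may differ, and for flat embedding vectors the SE "squeeze" step reduces to the identity.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root Mean Square normalization: rescale by the RMS of the activations."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: SiLU(W1 x) gates (W2 x), then project back to dim."""
    def __init__(self, dim, expansion=4):
        super().__init__()
        hidden = dim * expansion
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(dim, hidden, bias=False)  # value branch
        self.w3 = nn.Linear(hidden, dim, bias=False)  # output projection

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

class SEBlock(nn.Module):
    """Squeeze-and-Excitation gating over embedding dimensions."""
    def __init__(self, dim, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction),
            nn.SiLU(),
            nn.Linear(dim // reduction, dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(x)  # per-dimension gates in (0, 1)

class ModernBlock(nn.Module):
    """Pre-norm residual block: RMSNorm -> SwiGLU -> SEBlock -> residual."""
    def __init__(self, dim, drop_path=0.0, layer_scale_init=1e-4):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim)
        self.se = SEBlock(dim)
        self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))  # layer scale
        self.drop_path = drop_path

    def forward(self, x):
        # Stochastic depth: randomly skip the residual branch during training.
        if self.training and self.drop_path > 0 and torch.rand(()) < self.drop_path:
            return x
        return x + self.gamma * self.se(self.ffn(self.norm(x)))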

Configuration

  • Input dimension: 768 (embedding size)
  • Hidden layers: 12 transformer-style blocks
  • Expansion ratio: 4x hidden dimension in SwiGLU
  • Dropout: 0.1
  • Stochastic depth: drop probability scaled linearly across layers (0.0 → 0.1)
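
These settings map naturally onto a small config object. The field names below are hypothetical; the repository's config.py is authoritative.

from dataclasses import dataclass

@dataclass
class Config:
    input_dim: int = 768        # embedding size
    num_layers: int = 12        # transformer-style blocks
    expansion: int = 4          # SwiGLU hidden expansion ratio
    dropout: float = 0.1
    max_drop_path: float = 0.1  # stochastic depth cap, scaled 0.0 -> 0.1 across layers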

Training Objective: Hybrid Trajectory Learning

The model is trained with HybridTrajectoryLoss, combining two objectives:

1. Adaptive InfoNCE (Contrastive Component)

  • Learnable temperature parameter for dynamic scaling
  • Contrastive loss with label smoothing (0.1)
  • Pulls each input embedding toward its true target while treating the other targets in the batch as negatives
  • Equation: L_contrastive = CrossEntropy(logits / T, labels)
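
As a sketch of this component (not the exact released code), an in-batch InfoNCE with a learnable temperature can be written as:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveInfoNCE(nn.Module):
    """In-batch contrastive loss with a learnable temperature."""
    def __init__(self, init_temp=0.07, label_smoothing=0.1):
        super().__init__()
        # Learn log(T) so the temperature stays positive during optimization.
        self.log_temp = nn.Parameter(torch.tensor(init_temp).log())
        self.label_smoothing = label_smoothing

    def forward(self, predicted, targets):
        # predicted, targets: [B, D]; row i of `targets` is the true target for row i.
        pred = F.normalize(predicted, dim=-1)
        tgt = F.normalize(targets, dim=-1)
        logits = pred @ tgt.t() / self.log_temp.exp()  # [B, B] scaled similarities
        labels = torch.arange(pred.size(0), device=pred.device)
        return F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)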

2. Monotonic Ranking (Trajectory Component)

  • Enforces monotonically increasing similarity through the transaction sequence
  • Each step in the trajectory should have higher similarity than the previous step
  • Final embedding must achieve high similarity (ideally 1.0) with target
  • Margin constraint: sim[i+1] > sim[i] + 0.01
  • Ensures the model learns the path to the target, not just the endpoint
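
A plausible implementation of this component, assuming the trajectory is a [B, T, D] tensor of per-step embeddings (the released HybridTrajectoryLoss may differ in detail):

import torch
import torch.nn.functional as F

def monotonicity_loss(trajectory, target, margin=0.01):
    # trajectory: [B, T, D] per-step embeddings; target: [B, D].
    traj = F.normalize(trajectory, dim=-1)
    tgt = F.normalize(target, dim=-1).unsqueeze(1)   # [B, 1, D]
    sims = (traj * tgt).sum(dim=-1)                  # [B, T] cosine per step
    # Hinge penalty whenever sim[i+1] fails to exceed sim[i] by the margin.
    violations = F.relu(sims[:, :-1] + margin - sims[:, 1:])
    # Pull the final similarity toward 1.0.
    endpoint = (1.0 - sims[:, -1]).mean()
    return violations.mean() + endpoint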

Loss Formulation

L_total = L_contrastive + L_monotonic
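
Combining the two sketches above (the formula implies equal weighting):

infonce = AdaptiveInfoNCE()

def hybrid_trajectory_loss(trajectory, target):
    # trajectory: [B, T, D]; the final step is the model's output embedding.
    contrastive = infonce(trajectory[:, -1], target)
    ranking = monotonicity_loss(trajectory, target)
    return contrastive + ranking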

Why Trajectory Learning?

  • Transactions often evolve gradually toward their correct category
  • Intermediate embeddings should show progression toward the target
  • This inductive bias improves generalization and interpretability

Training Details

  • Optimizer: AdamW with weight decay (1e-4)
  • Learning rate: Cosine annealing from 3e-4 to 1e-6
  • Batch size: 128
  • Gradient clipping: 1.0
  • Epochs: 50 with early stopping (patience=5)
  • EMA (Exponential Moving Average): Decay=0.99 for evaluation stability
  • Augmentation: Input masking (p=0.15) and Gaussian noise (std=0.01) during training
  • Mixed Precision: AMP enabled for faster training on CUDA
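
The recipe above wires together into a loop along these lines. This is illustrative only: `train_loader` is assumed, and `return_trajectory` is a hypothetical flag standing in for however the real forward pass exposes its per-block embeddings.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

# Manual EMA of the weights (decay 0.99), used for evaluation.
ema = {k: v.detach().clone() for k, v in model.state_dict().items()}

for epoch in range(50):  # early stopping (patience=5) omitted for brevity
    model.train()
    for x, target in train_loader:
        x, target = x.to(device), target.to(device)
        # Augmentation: random input masking (p=0.15) and Gaussian noise (std=0.01).
        x = x * (torch.rand_like(x) > 0.15).float()
        x = x + 0.01 * torch.randn_like(x)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device, enabled=device == "cuda"):
            trajectory = model(x, return_trajectory=True)  # hypothetical [B, T, D]
            loss = hybrid_trajectory_loss(trajectory, target)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        scaler.step(optimizer)
        scaler.update()
        # EMA update after each optimizer step.
        with torch.no_grad():
            for k, v in model.state_dict().items():
                if v.dtype.is_floating_point:
                    ema[k].mul_(0.99).add_(v, alpha=0.01)
    scheduler.step()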

Performance Metrics

The model optimizes for:

  1. Last Similarity: Similarity of final embedding with target (Target: ≈1.0)
  2. Monotonicity Accuracy: % of transitions with strictly increasing similarity (Target: 100%)
  3. Contrastive Accuracy: Ability to distinguish true target from other targets in batch
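
Given the trajectory/target conventions from the loss section, the three metrics can be computed per batch as:

import torch
import torch.nn.functional as F

def evaluate_metrics(trajectory, targets):
    # trajectory: [B, T, D]; targets: [B, D].
    traj = F.normalize(trajectory, dim=-1)
    tgt = F.normalize(targets, dim=-1)
    sims = (traj * tgt.unsqueeze(1)).sum(dim=-1)   # [B, T] cosine per step
    last_similarity = sims[:, -1].mean()
    # Share of consecutive transitions with strictly increasing similarity.
    monotonicity_acc = (sims[:, 1:] > sims[:, :-1]).float().mean()
    # In-batch contrastive accuracy: final embedding ranks its own target first.
    logits = traj[:, -1] @ tgt.t()                 # [B, B]
    labels = torch.arange(tgt.size(0), device=tgt.device)
    contrastive_acc = (logits.argmax(dim=1) == labels).float().mean()
    return last_similarity, monotonicity_acc, contrastive_acc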

How to Load

from safetensors.torch import load_file
import torch
from config import Config
from model import ModernTrajectoryNet

# Load weights
weights = load_file("model.safetensors")

# Instantiate model
config = Config()
model = ModernTrajectoryNet(config)
model.load_state_dict(weights)
model.eval()

# Use model
with torch.no_grad():
    input_embedding = torch.randn(1, 768)  # Your transaction embedding
    output_embedding = model(input_embedding)
    print(output_embedding.shape)  # [1, 768]

Usage Example

import torch
from torch.nn.functional import normalize

# Assuming you have transaction embeddings and category embeddings
transaction_emb = model(input_embedding)  # [B, 768]

# Compute similarity with category embeddings
# (placeholder matrix; substitute your precomputed category embeddings)
category_embeddings = torch.randn(300, 768)  # [N_cats, 768]
category_embs = normalize(category_embeddings, p=2, dim=1)  # [N_cats, 768]
transaction_emb_norm = normalize(transaction_emb, p=2, dim=1)  # [B, 768]

similarities = torch.matmul(transaction_emb_norm, category_embs.t())  # [B, N_cats]
predicted_category = torch.argmax(similarities, dim=1)  # [B]

Intended Uses

  • Transaction categorization: Classify business transactions into merchant categories
  • Embedding refinement: Project raw transaction embeddings to discriminative space
  • Contrastive learning: Extract improved embeddings for downstream tasks
  • Research: Study trajectory-based learning for sequential decision problems

Limitations & Biases

  • Synthetic data: Trained on synthetic transaction strings generated from Foursquare Open-Source (FSQ OS) business names and categories using the qwen2.5-4b-instruct LLM
  • FSQ OS biases: Inherits biases from the FSQ OS dataset (e.g., geographic coverage, business type distribution)
  • Generation artifacts: LLM-based synthetic data may not reflect real-world transaction diversity
  • Category coverage: Limited to categories present in FSQ OS (typically 200-500 merchant types)
  • Language: Trained on English transaction strings; may not generalize to other languages

Recommendation: Validate performance on your specific transaction domain before production deployment.

Dataset

  • Source: Foursquare Open-Source (FSQ OS) business names and categories
  • Processing: LLM-based synthetic transaction generation
  • Size: ~1M synthetic transaction embeddings
  • Train/Val split: 90% / 10%

See the dataset card for more details.

Files in This Repository

  • model.safetensors: Model weights in the Hugging Face safetensors format (160 MB)
  • README.md: This file
  • LICENSE: Apache 2.0 license

License

Apache License 2.0. See LICENSE file for details.

Citation

If you use this model, please cite:

@software{transactionclassifier2024,
  title={TransactionClassifier: Embedding-based Transaction Categorization},
  author={HighkeyPrxneeth},
  year={2024},
  url={https://huggingface.co/HighkeyPrxneeth/ModernTrajectoryNet}
}

Contact & Support

For questions about the model architecture, training, or usage, feel free to reach out!
