!pip install peft transformers datasets accelerate -q

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import warnings
warnings.filterwarnings('ignore')

# Checkpoint identifier shared by the model and its tokenizer.
model_name = "gpt2"

# Pull the pretrained causal-LM weights, then the matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# Wrap the base model with LoRA adapters so that only the adapter weights
# are trained; the trainable-parameter summary is printed below.
lora_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,  # rank of the low-rank update matrices
    lora_alpha=16,  # LoRA scaling factor
    lora_dropout=0.1,
    # Module-name suffixes that receive adapters.
    # NOTE(review): "c_proj" presumably matches both the attention and the MLP
    # output projections in GPT-2 — confirm this is intended.
    target_modules=["c_attn", "c_proj"],
    # GPT-2 implements these projections as Conv1D layers, which store their
    # weights as (fan_in, fan_out). PEFT documents that this flag should be
    # True for such layers; when it is left at the default, PEFT flips it at
    # wrap time and emits a warning. State it explicitly instead.
    fan_in_fan_out=True,
)
model = get_peft_model(model, lora_cfg)
print("Trainable parameters:")
model.print_trainable_parameters()

# Fetch the first 1% of the raw WikiText-2 training split.
data = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="train[:1%]",
)

def tokenize(ex):
    """Tokenize the non-blank entries of a batch's ``"text"`` column.

    Entries that are falsy or consist only of whitespace are dropped before
    tokenization, so the result may contain fewer rows than the incoming
    batch. Each remaining text is truncated to at most 128 tokens.

    Args:
        ex: Batch mapping whose ``"text"`` key holds an iterable of strings.

    Returns:
        The output of the module-level ``tokenizer`` for the kept texts.
    """
    non_blank = [line for line in ex["text"] if line and line.strip()]
    return tokenizer(non_blank, truncation=True, max_length=128)

# Tokenize batch by batch and drop every original column, so only the
# tokenizer outputs remain in the dataset.
original_columns = data.column_names
data = data.map(
    function=tokenize,
    batched=True,
    remove_columns=original_columns,
)

# Collator configured for causal (non-masked) language modeling.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Hyper-parameters for a single short fine-tuning pass; no checkpoints are saved.
# NOTE(review): warmup_steps=50 may span most of the run on such a small
# dataset slice — confirm against the actual number of optimizer steps.
training_kwargs = {
    "output_dir": "lora-gpt2",
    "per_device_train_batch_size": 4,
    "num_train_epochs": 1,
    "learning_rate": 3e-4,
    "logging_steps": 10,
    "save_strategy": "no",
    "warmup_steps": 50,
}
args = TrainingArguments(**training_kwargs)

# Bind the LoRA-wrapped model, the tokenized dataset and the collator together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=data,
    data_collator=data_collator,
)

print("\nStarting training...")
trainer.train()

# Sample a continuation of a fixed prompt from the model.
print("\nGenerating text...")
model.eval()  # switch to inference mode (disables dropout)
prompt = "Artificial intelligence will"
# Encode the prompt and move the tensors to the device the model lives on.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    do_sample=True,  # sample instead of decoding greedily
    temperature=0.8,
    top_p=0.9,
    # GPT-2's configuration defines no pad token, so generate() would
    # otherwise fall back to EOS and log a warning on every call. Passing the
    # id explicitly keeps the same behavior without the warning.
    pad_token_id=tokenizer.eos_token_id,
)

print("\nGenerated text:")
# outputs[0] holds the prompt tokens followed by the sampled continuation.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))