import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, TrainingArguments)
from trl import SFTTrainer


def format_prompt(article, summary):
    # Llama 3 Instruct chat template: system instructions, the article as the
    # user turn, and the reference summary as the assistant turn.
    prompt_template = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are an expert summarizer. Your task is to provide a concise and accurate summary of the following news article. "
        "The summary should capture the main points and key facts from the article, and should be no longer than 150 words."
        "<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Article: {article}"
        "<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "{summary}<|eot_id|>"
    )
    return prompt_template.format(article=article, summary=summary)
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # pad on the right for causal-LM training


def tokenize_function(example):
    formatted_text = format_prompt(example["article"], example["highlights"])
    tokenized_output = tokenizer(formatted_text, truncation=True, max_length=8192)
    return tokenized_output
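
# --- Assumed setup, not shown in the original listing -------------------------
# The "article"/"highlights" columns match CNN/DailyMail, so it is used here as
# a stand-in; substitute your own dataset if that assumption does not hold.
raw_datasets = load_dataset("cnn_dailymail", "3.0.0", split="train")
train_dataset = raw_datasets
eval_dataset = load_dataset("cnn_dailymail", "3.0.0", split="validation")

# Load the base model in 4-bit for QLoRA-style fine-tuning; SFTTrainer applies
# the LoRA adapters itself when given peft_config.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# Illustrative LoRA hyperparameters; tune r, lora_alpha, and target_modules for
# the available hardware and task.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)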
# Optional pre-tokenized copy of the data; the SFTTrainer below tokenizes on its
# own via formatting_func, so this is mainly useful for inspecting sequence lengths.
tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=False, remove_columns=raw_datasets.column_names
)

output_dir = "./llama-summarization-ft"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    bf16=True,  # matches bnb_4bit_compute_dtype=torch.bfloat16 set above
    max_steps=2000,
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    eval_strategy="steps",
    lr_scheduler_type="cosine",
    load_best_model_at_end=True,
)

# peft_config, formatting_func, packing, and max_seq_length are accepted by
# trl's SFTTrainer rather than the plain transformers Trainer.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    max_seq_length=8192,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,
    # The raw columns are "article"/"highlights", so formatting_func builds the
    # training text; it receives a batch of examples, hence one prompt per row.
    formatting_func=lambda batch: [
        format_prompt(a, h) for a, h in zip(batch["article"], batch["highlights"])
    ],
)
trainer.train()

# Save the LoRA adapter and tokenizer so they can be reloaded for inference below.
trainer.save_model(output_dir)

base_model_for_inference = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model_for_inference = PeftModel.from_pretrained(base_model_for_inference, output_dir)
tokenizer_for_inference = AutoTokenizer.from_pretrained(output_dir)
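
# --- Minimal inference sketch, not in the original listing --------------------
# Reuses format_prompt with an empty summary and strips the final <|eot_id|> so
# generation starts at the assistant turn; article_text is a placeholder.
article_text = "..."  # replace with the article to summarize
prompt = format_prompt(article_text, "").rsplit("<|eot_id|>", 1)[0]
inputs = tokenizer_for_inference(
    prompt,
    return_tensors="pt",
    add_special_tokens=False,  # the template already includes <|begin_of_text|>
).to(model_for_inference.device)
with torch.no_grad():
    output_ids = model_for_inference.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,
        pad_token_id=tokenizer_for_inference.eos_token_id,
    )
summary = tokenizer_for_inference.decode(
    output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(summary)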