""" Run Gemini 2.5 Pro on pre-processed AML benchmark prompts. Designed for enterprise Vertex AI environments where HuggingFace/GitHub are blocked. Usage: # Run all JSONL files in a directory python run_gemini_enterprise.py --data-dir ./data --model gemini-2.5-pro # Run a single file (e.g., first 250 cases) python run_gemini_enterprise.py --data-dir ./data --files user_prompts_00.jsonl # Resume from crash (skips already-completed cases) python run_gemini_enterprise.py --data-dir ./data --resume Data files (download from thisisphume.me/prompt): system_prompt.txt — preamble + 12 few-shot examples user_prompts_XX.jsonl — 250 cases per file, 16 files total (3,753 cases) manifest.json — ground truth labels """ import json import os import time import glob import argparse def load_prompts(data_dir, files=None, method="ICL-AML"): """Load system prompt and user prompts from data directory. Methods: ICL-AML (default): system_prompt.txt (preamble + FS examples) + user_prompts (task + graph) ZS: system_prompt_ZS.txt (preamble) + graph_text + task_suffix_ZS.txt """ if method == "ZS": with open(os.path.join(data_dir, "system_prompt_ZS.txt"), "r", encoding="utf-8") as f: system_prompt = f.read() with open(os.path.join(data_dir, "task_suffix_ZS.txt"), "r", encoding="utf-8") as f: task_suffix = f.read() else: with open(os.path.join(data_dir, "system_prompt.txt"), "r", encoding="utf-8") as f: system_prompt = f.read() task_suffix = None user_prompts = [] if files: jsonl_files = [os.path.join(data_dir, f) for f in files] else: jsonl_files = sorted(glob.glob(os.path.join(data_dir, "user_prompts_*.jsonl"))) for fpath in jsonl_files: with open(fpath, "r", encoding="utf-8") as f: for line in f: entry = json.loads(line) if method == "ZS": # Extract graph text from the ICL-AML prompt # (strip the ICL-AML task header/footer, keep just the graph) prompt_text = entry["prompt"] # The graph starts with "=== Transaction Subgraph" graph_start = prompt_text.find("=== Transaction Subgraph") if graph_start >= 0: graph_text = prompt_text[graph_start:] # Strip the ICL-AML answer format at the end answer_fmt = graph_text.find("Answer Format:") if answer_fmt >= 0: graph_text = graph_text[:answer_fmt].rstrip() entry["prompt"] = graph_text + task_suffix # If can't find graph, use as-is (fallback) user_prompts.append(entry) # Load manifest for ground truth manifest_path = os.path.join(data_dir, "manifest.json") manifest = {} if os.path.exists(manifest_path): with open(manifest_path, "r", encoding="utf-8") as f: for entry in json.load(f): manifest[entry["case_id"]] = entry return system_prompt, user_prompts, manifest def run(args): import vertexai from vertexai.preview.generative_models import GenerativeModel # Init vertexai.init(project=args.project, location=args.location) generation_model = GenerativeModel(args.model) parameters = { "temperature": args.temperature, "max_output_tokens": args.max_output_tokens, "top_p": args.top_p, "top_k": args.top_k, } # Load data files = args.files.split(",") if args.files else None system_prompt, user_prompts, manifest = load_prompts(args.data_dir, files, method=args.method) print(f"Loaded {len(user_prompts)} cases, model={args.model}") # Resume support os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) completed_ids = set() results = [] if args.resume and os.path.exists(args.output): with open(args.output, "r", encoding="utf-8") as f: existing = json.load(f) results = existing.get("results", []) completed_ids = {r["case_id"] for r in results} print(f"Resuming: 

    # Token accumulators (seeded from any resumed rows so totals stay correct)
    total_input = sum(r.get("input_tokens", 0) for r in results)
    total_output = sum(r.get("output_tokens", 0) for r in results)
    total_thinking = sum(r.get("thinking_tokens", 0) for r in results)

    # Run
    n_new = 0
    for up in user_prompts:
        case_id = up["case_id"]
        if case_id in completed_ids:
            continue
        gt = manifest.get(case_id, {})

        t0 = time.time()
        try:
            response = generation_model.generate_content(
                system_prompt + up["prompt"],
                generation_config=parameters,
                stream=False,
            )
            latency_ms = (time.time() - t0) * 1000
            text = response.text
            usage = response.usage_metadata
            input_tok = usage.prompt_token_count
            output_tok = usage.candidates_token_count
            total_tok = usage.total_token_count
            # Thinking tokens are not reported as a separate field here; infer
            # them as whatever the total doesn't attribute to prompt or output.
            thinking_tok = total_tok - input_tok - output_tok
            error = None
        except Exception as e:
            latency_ms = (time.time() - t0) * 1000
            text = ""
            input_tok = output_tok = total_tok = thinking_tok = 0
            error = str(e)

        total_input += input_tok
        total_output += output_tok
        total_thinking += thinking_tok
        n_new += 1

        results.append({
            "case_id": case_id,
            "label": gt.get("label"),
            "typology_gt": gt.get("typology"),
            "raw_response": text,
            "input_tokens": input_tok,
            "output_tokens": output_tok,
            "total_tokens": total_tok,
            "thinking_tokens": thinking_tok,
            "latency_ms": latency_ms,
            "error": error,
        })

        # Progress
        if n_new % 50 == 0 or n_new == 1:
            done = len(completed_ids) + n_new
            total = len(user_prompts)
            remaining_cases = total - done
            avg_ms = sum(r["latency_ms"] for r in results[-n_new:]) / n_new
            eta_min = (avg_ms * remaining_cases) / 60000
            print(f" [{done}/{total}] "
                  f"in={total_input:,} out={total_output:,} think={total_thinking:,} "
                  f"~{eta_min:.0f}m left")

        # Incremental save every 100 cases
        if n_new % 100 == 0:
            _save(args.output, args.model, results,
                  total_input, total_output, total_thinking)

    # Final save
    _save(args.output, args.model, results, total_input, total_output, total_thinking)
    n_errors = sum(1 for r in results if r.get("error"))
    print(f"\nDone! {len(results)} predictions ({n_new} new), {n_errors} errors")
    print(f"Saved to {args.output}")
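
# Optional helper (not wired into run() above): a minimal retry sketch for
# transient Vertex AI failures. run() currently records any exception as an
# error row and moves on; to retry rate limits / transient 5xx errors in
# place, a wrapper along these lines could replace the direct
# generate_content call. The retry count and backoff constants below are
# illustrative assumptions, not tuned values, and catching Exception is
# deliberately broad — narrow it to the google.api_core exceptions you
# actually see in your environment.
def generate_with_retry(model, prompt, generation_config, retries=3, base_delay=5.0):
    """Call model.generate_content with simple exponential backoff."""
    for attempt in range(retries):
        try:
            return model.generate_content(prompt, generation_config=generation_config, stream=False)
        except Exception as e:
            if attempt == retries - 1:
                raise
            delay = base_delay * (2 ** attempt)
            print(f"  transient error: {e}; retrying in {delay:.0f}s ({attempt + 1}/{retries})")
            time.sleep(delay)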

def _save(output_file, model, results, total_input, total_output, total_thinking):
    n = len(results)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "model": model,
            "n_predictions": n,
            "token_stats": {
                "total_input_tokens": total_input,
                "total_output_tokens": total_output,
                "total_thinking_tokens": total_thinking,
                "avg_input_tokens": total_input / n if n else 0,
                "avg_output_tokens": total_output / n if n else 0,
            },
            "results": results,
        }, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Gemini on AML benchmark")
    parser.add_argument("--data-dir", default="./data",
                        help="Directory with system_prompt.txt + user_prompts_*.jsonl")
    parser.add_argument("--files", default=None,
                        help="Comma-separated JSONL files to run (default: all)")
    parser.add_argument("--output", default="gemini_results.json",
                        help="Output file path")
    parser.add_argument("--model", default="gemini-2.5-pro")
    parser.add_argument("--project", default="cidat-10027-int-3532", help="GCP project ID")
    parser.add_argument("--location", default="us-central1")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-output-tokens", type=int, default=4096)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--top-k", type=int, default=1)
    parser.add_argument("--method", default="ICL-AML", choices=["ICL-AML", "ZS"],
                        help="Prompting method: ICL-AML (few-shot) or ZS (zero-shot)")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from existing output file")
    args = parser.parse_args()
    run(args)
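
# Standalone post-hoc summary sketch (never called by this script; import
# summarize_results from another session if you want it). It reads the JSON
# written by _save above and reports error/latency/token aggregates. It
# deliberately does not parse raw_response into a predicted label: the answer
# format is defined by system_prompt.txt and is not assumed here.
def summarize_results(path="gemini_results.json"):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    results = data["results"]
    ok = [r for r in results if not r.get("error")]
    print(f"model={data['model']}  predictions={len(results)}  errors={len(results) - len(ok)}")
    if ok:
        avg_ms = sum(r["latency_ms"] for r in ok) / len(ok)
        avg_think = sum(r.get("thinking_tokens", 0) for r in ok) / len(ok)
        print(f"avg latency={avg_ms:.0f}ms  avg thinking tokens={avg_think:.0f}")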