# data_cleaning_env / dataset_factory.py
# CodeKnightDebjit's picture
# Upload folder using huggingface_hub
# eee232c verified
"""
dataset_factory.py
------------------
Generates (dirty_df, clean_df, metadata) triples for all 3 tasks.
Key design decisions:
- Fixed random seeds per task → reproducible grader scores
- clean_df is ALWAYS generated first, then dirt is injected
- metadata carries ground-truth info the grader needs (e.g. which
rows are real outliers vs valid extremes in Task 2)
- No external files needed — everything is generated in memory
"""
from __future__ import annotations
import copy
import random
import string
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import pandas as pd
# ── Reproducible seeds ────────────────────────────────────────────────────────
# One fixed seed per task id, so every reset() of a task rebuilds the same data.
SEEDS = dict(easy=42, medium=137, hard=999)
# ── Return type ───────────────────────────────────────────────────────────────
@dataclass
class TaskDataset:
    """Everything the environment and grader need for one episode.

    Produced by ``make_dataset``; one instance per task id.
    """

    task_id: str  # "easy", "medium" or "hard"
    dirty_df: pd.DataFrame  # corrupted frame handed to the agent
    clean_df: pd.DataFrame  # ground-truth frame (object dtype)
    schema_hint: str  # plain-English schema description
    # Easy/medium: number of cells that differ from clean_df at episode start.
    # Hard: rows of dirty_df x number of canonical columns (whole-frame scope).
    total_dirty_cells: int
    metadata: dict[str, Any] = field(default_factory=dict)
    # metadata keys written by the task factories:
    #   easy:
    #     "injected_cells"      — list of (row, column) pairs that were dirtied
    #   medium:
    #     "outlier_rows"        — row indices whose amount was replaced with an
    #                             erroneous value (extreme high or negative)
    #     "valid_extreme_rows"  — rows with large but legitimate amounts that
    #                             must stay
    #     "typo_cells"          — list of (row, dirty_value, clean_value)
    #                             category misspellings
    #   hard:
    #     "canonical_columns"   — ordered list of canonical column names
    #     "canonical_lookup"    — {alias: canonical_name} mapping
    #     "source_aliases"      — {source: {canonical_name: alias}}
    #     "source_date_formats" — {source: strftime format}
    #     "duplicate_pairs"     — (original_idx, duplicate_idx) pairs, indexed
    #                             into the dirty frame before it was shuffled
    #     "n_clean_rows"        — number of rows in clean_df
# ── Public API ────────────────────────────────────────────────────────────────
def make_dataset(task_id: str) -> TaskDataset:
"""Entry point. Call this from the environment's reset()."""
if task_id == "easy":
return _make_easy()
elif task_id == "medium":
return _make_medium()
elif task_id == "hard":
return _make_hard()
else:
raise ValueError(f"Unknown task_id: {task_id!r}. Must be easy/medium/hard.")
def count_dirty_cells(dirty_df: pd.DataFrame, clean_df: pd.DataFrame) -> int:
    """Count the cells whose string form differs between the two frames.

    Both frames are rendered with ``astype(str)`` (so NaN in both compares
    equal as "nan") and re-indexed 0..n-1 before the element-wise ``!=``;
    they must therefore share shape and column labels.
    """
    left, right = (
        frame.astype(str).reset_index(drop=True) for frame in (dirty_df, clean_df)
    )
    mismatches = left != right
    return int(mismatches.to_numpy().sum())
# ── Task 1: easy ─────────────────────────────────────────────────────────────
#
# 50-row sales CSV.
# Clean schema:
# order_id (int), customer (str), product (str), category (str),
# price (float, 2dp), quantity (int), order_date (YYYY-MM-DD),
# region (str)
#
# Injected issues (29 dirty cells total):
# • 10 wrong-type cells — numeric column contains a non-numeric token
# • 8 missing values — NaN in various columns
# • 5 bad dates — future year (2099-xx-xx)
# • 6 whitespace cells — leading/trailing spaces in string columns
def _make_easy() -> TaskDataset:
    """Build Task 1 ("easy"): a 50-row sales table with 29 injected bad cells.

    ``clean`` is generated first from ``SEEDS["easy"]``; dirt is then written
    into an object-dtype copy. Every injection targets a ``(row, column)``
    cell that has not been dirtied yet, so each yields exactly one differing
    cell:

    * 10 non-numeric tokens in ``price`` / ``quantity``
    * 8 NaNs in customer / product / price / quantity / region
    * 5 far-future (2099) dates in ``order_date``
    * 6 values wrapped in leading/trailing spaces in string columns

    NOTE: the order of every ``rng`` / ``np_rng`` call below determines the
    generated data; reordering them changes the dataset for the same seed.

    Returns:
        TaskDataset whose ``metadata["injected_cells"]`` lists the dirtied
        ``(row, column)`` pairs, sorted so the list is identical across runs.
    """
    rng = random.Random(SEEDS["easy"])
    np_rng = np.random.default_rng(SEEDS["easy"])
    n = 50
    categories = ["Electronics", "Clothing", "Home", "Sports", "Books"]
    regions = ["North", "South", "East", "West"]
    products = ["Widget A", "Widget B", "Gadget X", "Gadget Y", "Item Z"]
    customers = [f"Customer_{i:03d}" for i in range(1, 31)]

    # ── Build clean DataFrame ────────────────────────────────────────────────
    # Dict entries are evaluated top to bottom, which fixes the RNG call order.
    clean = pd.DataFrame({
        "order_id": range(1001, 1001 + n),
        "customer": [rng.choice(customers) for _ in range(n)],
        "product": [rng.choice(products) for _ in range(n)],
        "category": [rng.choice(categories) for _ in range(n)],
        "price": np_rng.uniform(5.0, 500.0, n).round(2),
        "quantity": np_rng.integers(1, 20, n),
        "order_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
        "region": [rng.choice(regions) for _ in range(n)],
    })
    clean["price"] = clean["price"].astype(float)
    clean["quantity"] = clean["quantity"].astype(int)

    # ── Inject dirt ──────────────────────────────────────────────────────────
    # object dtype lets numeric columns hold strings and NaN side by side.
    dirty = clean.copy(deep=True).astype(object)
    injected: set[tuple[int, str]] = set()

    def pick_fresh(col: str, exclude: set) -> int:
        """Pick a random row whose (row, col) cell is not in ``exclude``."""
        rows = [r for r in range(n) if (r, col) not in exclude]
        return rng.choice(rows)

    # 10 wrong-type cells in numeric columns. The last token is an em dash
    # (it was previously stored as the mojibake "β€”").
    bad_words = ["N/A", "unknown", "missing", "null", "TBD", "??", "-", "n/a", "none", "—"]
    for word, col in zip(bad_words, rng.choices(["price", "quantity"], k=10)):
        row = pick_fresh(col, injected)
        dirty.at[row, col] = word
        injected.add((row, col))

    # 8 missing values in various columns
    missing_cols = rng.choices(["customer", "product", "price", "quantity", "region"], k=8)
    for col in missing_cols:
        row = pick_fresh(col, injected)
        dirty.at[row, col] = np.nan
        injected.add((row, col))

    # 5 bad dates — far-future year
    bad_date_templates = [
        "2099-01-15", "2099-07-04", "2099-12-31", "2099-03-22", "2099-11-11"
    ]
    for bad_date in bad_date_templates:
        row = pick_fresh("order_date", injected)
        dirty.at[row, "order_date"] = bad_date
        injected.add((row, "order_date"))

    # 6 whitespace cells in string columns. The cell is fresh, so the value
    # read from `dirty` is still the clean one.
    ws_cols = rng.choices(["customer", "product", "category", "region"], k=6)
    for col in ws_cols:
        row = pick_fresh(col, injected)
        orig = str(dirty.at[row, col])
        dirty.at[row, col] = f" {orig} "
        injected.add((row, col))

    # count_dirty_cells already compares the string forms of both frames.
    dirty_cell_count = count_dirty_cells(dirty, clean)

    # Product names contain internal spaces ("Widget A"); only leading/trailing
    # whitespace is dirt, so the hint must not say "no spaces" for product.
    schema_hint = (
        "Sales orders dataset. Expected columns: "
        "order_id (integer), customer (string, no leading/trailing spaces), "
        "product (string, no leading/trailing spaces), "
        "category (one of: Electronics/Clothing/Home/Sports/Books), "
        "price (float, 2 decimal places, no text), "
        "quantity (integer, no text), "
        "order_date (YYYY-MM-DD format, year must be 2023 or 2024), "
        "region (one of: North/South/East/West, no spaces). "
        "No missing values allowed."
    )
    return TaskDataset(
        task_id="easy",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        # Sorted: iterating a set of str-containing tuples is hash-order and
        # therefore varies between processes (PYTHONHASHSEED).
        metadata={"injected_cells": sorted(injected)},
    )
# ── Task 2: medium ────────────────────────────────────────────────────────────
#
# 200-row customer transaction CSV.
# Clean schema:
# tx_id (int), customer_id (int), amount (float), tx_date (YYYY-MM-DD),
# category (str), country (str), status (str)
#
# Injected issues:
# • 15 erroneous amounts — replaced with an extreme high ($5000–$15000) or an
#   impossible negative (-$500 to -$10); these should be removed/capped
# • 5 valid extremes — genuinely large transactions ($900–$2000), must NOT be removed
# • 12 category typos — slight misspellings
def _make_medium() -> TaskDataset:
    """Build Task 2 ("medium"): a 200-row customer transaction table.

    Amounts start normally distributed (mean 150, sd 60, clipped to
    [5, 800]). The dirty copy then receives:

    * 15 erroneous amounts (``outlier_rows``): with probability 0.7 an extreme
      high in [5000, 15000], otherwise an impossible negative in [-500, -10].
      ``clean_df`` keeps the original in-range amount on these rows.
    * 5 valid extremes (``valid_extreme_rows``): amounts in [900, 2000]
      written to BOTH frames, so dropping them loses correct data.
    * 12 category typos (``typo_cells``), never placed on an outlier row.

    NOTE: the order of ``rng`` / ``np_rng`` calls fixes the generated data.
    """
    rng = random.Random(SEEDS["medium"])
    np_rng = np.random.default_rng(SEEDS["medium"])
    n = 200
    categories = ["Food", "Electronics", "Travel", "Healthcare", "Entertainment"]
    countries = ["US", "UK", "CA", "AU", "DE"]
    statuses = ["completed", "pending", "refunded"]

    # ── Build clean base ────────────────────────────────────────────────────
    # Normal transaction amounts: mean $150, sd $60, clipped to [5, 800]
    amounts = np_rng.normal(150, 60, n).clip(5, 800).round(2)
    clean = pd.DataFrame({
        "tx_id": range(9001, 9001 + n),
        "customer_id": np_rng.integers(1, 501, n),
        "amount": amounts,
        "tx_date": _random_dates(np_rng, n, "2023-01-01", "2024-06-30"),
        "category": [rng.choice(categories) for _ in range(n)],
        "country": [rng.choice(countries) for _ in range(n)],
        "status": [rng.choice(statuses) for _ in range(n)],
    })

    # ── Choose outlier rows (15) — will be injected with erroneous amounts ──
    all_rows = list(range(n))
    outlier_rows: list[int] = rng.sample(all_rows, 15)
    outlier_set = set(outlier_rows)
    remaining = [r for r in all_rows if r not in outlier_set]

    # ── Choose valid extreme rows (5) — large but legitimate ─────────────────
    # Not in outlier_rows; their amounts (>= 900) exceed the clipped max of 800.
    valid_extreme_rows: list[int] = rng.sample(remaining, 5)

    # ── Build dirty DataFrame ────────────────────────────────────────────────
    dirty = clean.copy(deep=True).astype(object)

    # True errors: very high, or negative (impossible for a transaction).
    for row in outlier_rows:
        if rng.random() > 0.3:
            dirty.at[row, "amount"] = round(rng.uniform(5000, 15000), 2)  # extreme high
        else:
            dirty.at[row, "amount"] = round(rng.uniform(-500, -10), 2)  # negative

    # Valid extremes go into clean AND dirty — the agent must keep them.
    for row in valid_extreme_rows:
        valid_large = round(rng.uniform(900, 2000), 2)
        clean.at[row, "amount"] = valid_large
        dirty.at[row, "amount"] = valid_large

    # 12 category typos. rng.sample returns 12 distinct rows and every
    # category has misspellings, so each sampled row gets exactly one typo.
    typo_map: dict[str, list[str]] = {
        "Electronics": ["Electrnics", "Electronis", "Electonics"],
        "Food": ["Foood", "Fod", "Fo0d"],
        "Travel": ["Travle", "Trevel", "Travell"],
        "Healthcare": ["Helthcare", "Healtcare", "Heathcare"],
        "Entertainment": ["Entertainmnt", "Entertainmet", "Entertainmen"],
    }
    typo_cells: list[tuple[int, str, str]] = []  # (row, dirty_val, clean_val)
    for row in rng.sample(remaining, min(12, len(remaining))):
        orig_cat = str(clean.at[row, "category"])
        misspellings = typo_map.get(orig_cat)
        if misspellings:
            bad = rng.choice(misspellings)
            dirty.at[row, "category"] = bad
            typo_cells.append((row, bad, orig_cat))

    # count_dirty_cells already compares the string forms of both frames.
    dirty_cell_count = count_dirty_cells(dirty, clean)
    schema_hint = (
        "Customer transactions dataset. Expected columns: "
        "tx_id (integer), customer_id (integer 1–500), "
        "amount (float, must be positive; realistic range is $5–$2000; "
        "amounts above $2000 or below $0 are data errors), "
        "tx_date (YYYY-MM-DD), "
        "category (one of: Food/Electronics/Travel/Healthcare/Entertainment — exact spelling), "
        "country (two-letter code: US/UK/CA/AU/DE), "
        "status (one of: completed/pending/refunded). "
        "Note: some large transactions ($900–$2000) are legitimate — do not remove them. "
        "Only remove rows where the amount is clearly erroneous (negative or > $2000)."
    )
    return TaskDataset(
        task_id="medium",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        metadata={
            "outlier_rows": outlier_rows,
            "valid_extreme_rows": valid_extreme_rows,
            "typo_cells": typo_cells,  # [(row, dirty_val, clean_val)]
        },
    )
# ── Task 3: hard ──────────────────────────────────────────────────────────────
#
# 400-row CSV merged from 3 fictional data sources.
# Each source uses different column names for the same concepts.
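# Issues:
# • Inconsistent column naming (up to 3 aliases per concept)
# • Mixed date formats across sources (ISO, US, EU)
# • 30 duplicate rows (exact, or near-duplicate with the amount changed by ±1%)
# • No per-source schema documentation — the agent must map each raw column to
#   its canonical name (the schema hint lists only the target columns and a few
#   alias examples)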
#
# Canonical schema (what the agent must produce):
# record_id, customer_id, full_name, email, amount,
# currency, purchase_date (YYYY-MM-DD), product_name, region
_CANONICAL_COLS = [
"record_id", "customer_id", "full_name", "email",
"amount", "currency", "purchase_date", "product_name", "region",
]
# Column aliases per source
_SOURCE_ALIASES = {
"source_a": {
"record_id": "record_id",
"customer_id": "cust_id",
"full_name": "name",
"email": "email_address",
"amount": "sale_amount",
"currency": "ccy",
"purchase_date":"date",
"product_name": "item",
"region": "territory",
},
"source_b": {
"record_id": "id",
"customer_id": "customer_id",
"full_name": "full_name",
"email": "contact_email",
"amount": "value",
"currency": "currency",
"purchase_date":"purchase_date",
"product_name": "product",
"region": "area",
},
"source_c": {
"record_id": "RecordID",
"customer_id": "CustomerID",
"full_name": "CustomerName",
"email": "Email",
"amount": "Amount",
"currency": "Currency",
"purchase_date":"PurchaseDate",
"product_name": "ProductName",
"region": "Region",
},
}
# Date format used by each source
_SOURCE_DATE_FORMATS = {
"source_a": "%Y-%m-%d", # ISO: 2023-04-15
"source_b": "%m/%d/%Y", # US: 04/15/2023
"source_c": "%d.%m.%Y", # EU: 15.04.2023
}
def _make_hard() -> TaskDataset:
    """Build Task 3 ("hard"): 430 raw rows merged from three inconsistent sources.

    Three sources (134 + 133 + 133 rows) each use their own column names
    (``_SOURCE_ALIASES``) and date format (``_SOURCE_DATE_FORMATS``).

    * ``clean_df`` holds every block renamed to ``_CANONICAL_COLS``, with ISO dates.
    * ``dirty_df`` holds the raw blocks concatenated as-is, so each row is NaN
      in the other sources' columns. 30 duplicated rows are appended; about 40%
      of them have their amount scaled by a factor in [0.99, 1.01]. The whole
      frame is then shuffled.

    NOTE: the order of ``rng`` / ``np_rng`` calls fixes the generated data.

    Returns:
        TaskDataset whose metadata carries two duplicate mappings:

        * ``duplicate_pairs``: (original, duplicate) indices into the dirty
          frame *before* the shuffle, kept for backward compatibility.
        * ``duplicate_pairs_dirty``: the same pairs as row positions in the
          returned, shuffled ``dirty_df``.
    """
    rng = random.Random(SEEDS["hard"])
    np_rng = np.random.default_rng(SEEDS["hard"])
    currencies = ["USD", "EUR", "GBP"]
    regions = ["APAC", "EMEA", "AMER", "LATAM"]
    products = [
        "Pro Subscription", "Enterprise License", "Support Package",
        "Training Course", "Hardware Bundle", "Consulting Day",
    ]

    def _source_block(source: str, n: int, id_start: int) -> pd.DataFrame:
        """Generate ``n`` raw rows using one source's column names and date format."""
        aliases = _SOURCE_ALIASES[source]
        date_fmt = _SOURCE_DATE_FORMATS[source]
        cust_ids = np_rng.integers(2001, 3001, n)
        amounts = np_rng.uniform(100, 5000, n).round(2)
        iso_dates = _random_dates(np_rng, n, "2022-01-01", "2024-06-30")
        # Re-render the ISO dates in this source's own format.
        formatted_dates = [pd.to_datetime(d).strftime(date_fmt) for d in iso_dates]
        names = [_random_name(rng) for _ in range(n)]
        emails = [_name_to_email(nm) for nm in names]
        # Entries are evaluated top to bottom, which fixes the rng call order.
        data = {
            aliases["record_id"]: range(id_start, id_start + n),
            aliases["customer_id"]: cust_ids.tolist(),
            aliases["full_name"]: names,
            aliases["email"]: emails,
            aliases["amount"]: amounts.tolist(),
            aliases["currency"]: [rng.choice(currencies) for _ in range(n)],
            aliases["purchase_date"]: formatted_dates,
            aliases["product_name"]: [rng.choice(products) for _ in range(n)],
            aliases["region"]: [rng.choice(regions) for _ in range(n)],
        }
        return pd.DataFrame(data)

    # Three sources, 134 + 133 + 133 = 400 rows, record ids 1..400
    block_a = _source_block("source_a", 134, id_start=1)
    block_b = _source_block("source_b", 133, id_start=135)
    block_c = _source_block("source_c", 133, id_start=268)

    # ── Canonical (clean) dataframe ─────────────────────────────────────────
    def _to_canonical(df: pd.DataFrame, source: str) -> pd.DataFrame:
        """Rename a raw block to canonical columns and convert its dates to ISO."""
        rev = {v: k for k, v in _SOURCE_ALIASES[source].items()}
        renamed = df.rename(columns=rev)
        renamed["purchase_date"] = pd.to_datetime(
            renamed["purchase_date"],
            format=_SOURCE_DATE_FORMATS[source],
        ).dt.strftime("%Y-%m-%d")
        return renamed[_CANONICAL_COLS]

    clean_a = _to_canonical(block_a, "source_a")
    clean_b = _to_canonical(block_b, "source_b")
    clean_c = _to_canonical(block_c, "source_c")
    clean = pd.concat([clean_a, clean_b, clean_c], ignore_index=True)
    clean["record_id"] = range(1, len(clean) + 1)

    # ── Dirty dataframe = concat of raw source blocks ────────────────────────
    # (columns are still in aliased form, dates in source-specific format)
    dirty = pd.concat([block_a, block_b, block_c], ignore_index=True)

    # ── Inject 30 duplicate rows ─────────────────────────────────────────────
    n_clean = len(dirty)
    # Each row has a non-null value in exactly one of these (one per source).
    amount_cols = [aliases["amount"] for aliases in _SOURCE_ALIASES.values()]
    sampled_orig = rng.sample(range(n_clean), 30)
    duplicate_rows_to_inject: list[pd.DataFrame] = []
    duplicate_pairs: list[tuple[int, int]] = []
    for orig_idx in sampled_orig:
        dup = dirty.iloc[[orig_idx]].copy()
        # Near-duplicate: 40% chance the amount is scaled by a factor in [0.99, 1.01].
        if rng.random() < 0.4:
            for amt_col in amount_cols:
                if amt_col in dup.columns and pd.notna(dup.iloc[0].get(amt_col)):
                    old_val = dup.at[dup.index[0], amt_col]
                    dup.at[dup.index[0], amt_col] = round(
                        float(old_val) * rng.uniform(0.99, 1.01), 2
                    )
                    break
        duplicate_rows_to_inject.append(dup)
        # Duplicates are appended in order after the n_clean originals.
        duplicate_pairs.append((orig_idx, n_clean + len(duplicate_pairs)))
    dirty = pd.concat([dirty] + duplicate_rows_to_inject, ignore_index=True)

    # Shuffle so duplicates aren't obviously at the bottom. shuffled.index
    # still holds the pre-shuffle positions, so invert it to translate
    # duplicate_pairs into positions in the frame actually returned.
    shuffled = dirty.sample(frac=1, random_state=SEEDS["hard"])
    position_after_shuffle = {old: new for new, old in enumerate(shuffled.index)}
    dirty = shuffled.reset_index(drop=True)
    duplicate_pairs_dirty = [
        (position_after_shuffle[orig], position_after_shuffle[dup_idx])
        for orig, dup_idx in duplicate_pairs
    ]

    # Build canonical alias lookup for grader (alias → canonical name)
    canonical_lookup: dict[str, str] = {
        alias: canonical
        for aliases in _SOURCE_ALIASES.values()
        for canonical, alias in aliases.items()
    }

    dirty_cell_count = len(dirty) * len(_CANONICAL_COLS)  # hard task: whole-df scope
    schema_hint = (
        "Merged dataset from 3 sources with inconsistent schemas. "
        "Your goal is to produce a single clean DataFrame with these canonical columns: "
        "record_id (integer, unique), customer_id (integer), full_name (string), "
        "email (string), amount (float), currency (one of: USD/EUR/GBP), "
        "purchase_date (YYYY-MM-DD), product_name (string), region (one of: APAC/EMEA/AMER/LATAM). "
        "Column names in the raw data vary by source (e.g. 'cust_id', 'customer_id', 'CustomerID' "
        "all mean customer_id). Date formats also vary (ISO, US MM/DD/YYYY, EU DD.MM.YYYY). "
        "There are also ~30 duplicate rows (some exact, some near-duplicate). "
        "Remove duplicates, normalise all column names and date formats."
    )
    return TaskDataset(
        task_id="hard",
        dirty_df=dirty,
        clean_df=clean.astype(object),
        schema_hint=schema_hint,
        total_dirty_cells=dirty_cell_count,
        metadata={
            "canonical_columns": _CANONICAL_COLS,
            "canonical_lookup": canonical_lookup,  # alias → canonical name
            "source_aliases": _SOURCE_ALIASES,
            "source_date_formats": _SOURCE_DATE_FORMATS,
            "duplicate_pairs": duplicate_pairs,  # (original_idx, dup_idx) in pre-shuffle dirty
            "duplicate_pairs_dirty": duplicate_pairs_dirty,  # same pairs, positions in dirty_df
            "n_clean_rows": len(clean),
        },
    )
# ── Internal helpers ──────────────────────────────────────────────────────────
def _random_dates(
rng: np.random.Generator,
n: int,
start: str,
end: str,
) -> list[str]:
"""Generate n random ISO-format date strings between start and end."""
start_ts = pd.Timestamp(start)
end_ts = pd.Timestamp(end)
delta_days = (end_ts - start_ts).days
offsets = rng.integers(0, delta_days, n)
return [
(start_ts + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
for d in offsets
]
_FIRST_NAMES = [
"Alice", "Bob", "Carol", "David", "Eva", "Frank", "Grace", "Henry",
"Iris", "Jack", "Karen", "Leo", "Mia", "Nathan", "Olivia", "Paul",
"Quinn", "Rosa", "Sam", "Tara", "Uma", "Victor", "Wendy", "Xavier",
"Yuki", "Zara",
]
_LAST_NAMES = [
"Smith", "Jones", "Williams", "Brown", "Taylor", "Davies", "Evans",
"Wilson", "Thomas", "Roberts", "Johnson", "Lee", "Martin", "Garcia",
"Martinez", "Anderson", "Thompson", "White", "Harris", "Clark",
]
def _random_name(rng: random.Random) -> str:
    """Return a random "First Last" name; the first name is drawn first."""
    given = rng.choice(_FIRST_NAMES)
    family = rng.choice(_LAST_NAMES)
    return " ".join((given, family))
def _name_to_email(name: str) -> str:
first, last = name.lower().split()
domains = ["example.com", "mail.com", "inbox.net", "corp.io"]
return f"{first}.{last}@{domains[hash(name) % len(domains)]}"
# ── Smoke test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: build each task once and print a short summary.
    for name in ("easy", "medium", "hard"):
        dataset = make_dataset(name)
        meta = dataset.metadata
        print("\n" + "─" * 60)
        print(f"Task: {name.upper()}")
        print(f" dirty shape : {dataset.dirty_df.shape}")
        print(f" clean shape : {dataset.clean_df.shape}")
        print(f" dirty cells : {dataset.total_dirty_cells}")
        print(f" schema hint : {dataset.schema_hint[:80]}…")
        print(f" metadata keys: {list(meta.keys())}")
        if name == "easy":
            print("\n Sample dirty rows (price/quantity col):")
            price_text = dataset.dirty_df["price"].astype(str)
            suspicious = price_text.str.contains(r"[a-zA-Z]|nan", na=True)
            sample = dataset.dirty_df.loc[suspicious, ["order_id", "price", "quantity"]]
            print(sample.head(3).to_string(index=False))
        elif name == "medium":
            print(f"\n Outlier rows (first 5): {meta['outlier_rows'][:5]}")
            print(f" Valid extreme rows: {meta['valid_extreme_rows']}")
        elif name == "hard":
            print(f"\n Raw column names: {list(dataset.dirty_df.columns)}")
            print(f" Duplicate pairs (first 3): {meta['duplicate_pairs'][:3]}")