# baby_discriminator
## Challenge Overview
We need to distinguish between two types of random vectors 200 times in a row to get the flag.
- Bit 1: The vector (length 140) is completely random. Each element is an independent, unseeded draw from `choose_one()`.
- Bit 0: The vector has a hidden structure. Each element `v[j]` (for `j >= 5`) is generated using a PRNG seeded with the previous 5 elements (`v[j-5:j]`). However, to make it harder, some elements are randomly replaced with noise: between 0 and 64 positions, picked with replacement, are overwritten with fresh unseeded draws.
We are provided with a baby-discriminator.py file:
```python
import random
from hashlib import md5, sha256
import secrets
import string
import numpy as np
import sys

try:
    from secret import flag
except ImportError:
    flag = "0ops{this_is_a_test_flag}"

window = 5
total_nums = 20000
vector_size = 140

def proof_of_work():
    challenge = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
    difficulty = 6
    print(f"Proof of Work challenge:")
    print(f"sha256({challenge} + ???) starts with {'0' * difficulty}")
    sys.stdout.write("Enter your answer: ")
    sys.stdout.flush()
    answer = sys.stdin.readline().strip()
    hash_res = sha256((challenge + answer).encode()).hexdigest()
    if hash_res.startswith('0' * difficulty):
        return True
    return False

def choose_one(seed = None):
    p_v = 10 ** np.random.uniform(0, 13, size=total_nums)
    if seed is not None:
        seed_int = int(seed, 16)
        rng = np.random.default_rng(seed_int)
    else:
        rng = np.random.default_rng()
    us = rng.random(total_nums)
    return int(np.argmax(np.log(us) / p_v))

def get_vector(bit):
    if bit == 0:
        v = []
        for _ in range(vector_size):
            seed = md5(str(v[-window:]).encode()).hexdigest() if len(v) >= window else None
            v.append(choose_one(seed))
        to_change = secrets.randbelow(65)
        pos = random.choices(range(vector_size), k=to_change)
        for p in pos:
            v[p] = choose_one()
        return v
    else:
        return [choose_one() for _ in range(vector_size)]

if not proof_of_work():
    print("PoW verification failed!")
    exit()

banner = """..."""  # ASCII-art title banner (omitted)
print(banner)
print("Are u ready to play the game")
play_times = 200
for i in range(play_times):
    bit = secrets.randbelow(2)
    v = get_vector(bit)
    print("Vector: ", v)
    print("Please tell me the bit of the vector")
    try:
        user_bit = int(input())
    except ValueError:
        print("Invalid input")
        exit()
    if user_bit != bit:
        print("Wrong answer")
        exit()
print("You are a good guesser, the flag is ", flag)
```
## The Generation Process
The core function is `choose_one(seed)`:
- It generates random weights `p_v` (`10 ** uniform(0, 13)`, i.e. log-uniform between 1 and 10^13).
- It generates random values `us`.
- It picks the index `k` that maximizes `log(us[k]) / p_v[k]`.

Crucially:
- `p_v` is always generated using the global NumPy RNG state (which we don't know).
- `us` is generated using a local RNG. If a `seed` is provided, `us` is deterministic and known to us!
For Bit 0, `v[j]` is generated with `seed = md5(str(v[j-5:j]).encode()).hexdigest()`. This means that if we know the previous 5 numbers (and none of them was later replaced by noise), we can reconstruct the exact `us` array that was used to generate `v[j]`.
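A quick check that `us` really is reproducible: the sketch below rebuilds `us` from a window of five values the same way the server does (`md5` over `str(list)`, hex digest parsed as an integer seed for `np.random.default_rng`). The window values and the helper name `regenerate_us` are hypothetical.

```python
import hashlib

import numpy as np

TOTAL_NUMS = 20000


def regenerate_us(window):
    """Rebuild the us array the server draws for the element after `window`."""
    # Same derivation as get_vector(): md5 over str(list), hex digest used as an int seed.
    seed = int(hashlib.md5(str(window).encode()).hexdigest(), 16)
    return np.random.default_rng(seed).random(TOTAL_NUMS)


window = [17, 4242, 19999, 0, 731]       # hypothetical previous five elements
first = regenerate_us(window)
second = regenerate_us(list(window))     # an independent rebuild from the same values
print(np.array_equal(first, second))     # True
print(str(window))                       # [17, 4242, 19999, 0, 731]
```

The exact string form matters here: the solver receives the vector as a list of plain Python ints (via `ast.literal_eval`), so `str()` reproduces the server's `"[a, b, c, d, e]"` formatting byte for byte.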
## The Statistical Flaw
Even though we don't know `p_v`, we know that the winning index `v[j]` maximizes `log(us[k]) / p_v[k]`.
Since `log(us)` is negative and `p_v` is positive, every candidate score is negative, so the maximum is the score closest to 0. This happens when `us[k]` is large (close to 1) and `p_v[k]` is large.
Therefore, the winning index `v[j]` tends to have a larger-than-average `us[v[j]]` value.
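The intuition can be made precise with the standard competing-exponentials argument (an added derivation, not part of the original solver). Write $p_k$ for `p_v[k]`, $u_k$ for `us[k]`, $E_k = -\ln u_k$ and $P = \sum_i p_i$. Each $E_k$ is $\mathrm{Exp}(1)$, so $T_k = E_k / p_k$ is exponential with rate $p_k$, and maximizing $\ln u_k / p_k = -T_k$ means minimizing $T_k$:

$$
\Pr[v_j = k] = \frac{p_k}{P}, \qquad \min_i T_i \sim \mathrm{Exp}(P) \ \text{independently of which index wins},
$$

$$
E_{v_j} = p_{v_j} \min_i T_i \sim \mathrm{Exp}\!\left(\frac{P}{p_{v_j}}\right), \qquad \mathbb{E}\left[1 - u_{v_j}\right] \le \mathbb{E}\left[E_{v_j}\right] = \frac{p_{v_j}}{P}.
$$

Because `q` is essentially $1 - u_{v_j}$, it is tiny whenever the winner's share $p_{v_j}/P$ is tiny. With 20000 weights $10^{x}$, $x \sim U(0, 13)$, one gets $P \approx 20000 \cdot 10^{13} / (13 \ln 10) \approx 6.7 \times 10^{15}$, so even the largest weight is only about $1.5 \times 10^{-3}$ of $P$. This is the classic weighted-sampling trick of ranking items by the key $u_k^{1/p_k}$.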
If `v[j]` was generated from the seed `md5(str(v[j-5:j]))`:
- We can regenerate the same `us` array.
- We look at `us[v[j]]`. It should be one of the larger values in `us`.
- We compute the rank `q`: the fraction of `us` values that are larger than `us[v[j]]`.
- If `v[j]` is the true generated value, `q` will be small (close to 0).
- If `v[j]` is random (or was replaced by noise), it was chosen independently of our regenerated `us`, so `us[v[j]]` is just a uniform draw and `q` is uniformly distributed between 0 and 1.
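A small simulation makes the gap visible. It reuses the server's selection rule, takes `p_v` from a hypothetical fixed-seed generator `weights_rng` (the real global state is unknown), and compares `q` for a value drawn from a window's own `us` against `q` for a value drawn from an unrelated seed (`seed + 1`), which plays the role of noise. The helper names are illustrative.

```python
import hashlib

import numpy as np

TOTAL_NUMS = 20000
# Hypothetical stand-in for the server's unknown global NumPy state (only used for p_v).
weights_rng = np.random.default_rng(2024)


def choose_one(seed):
    # Same selection rule as the server's choose_one().
    p_v = 10 ** weights_rng.uniform(0, 13, size=TOTAL_NUMS)
    us = np.random.default_rng(seed).random(TOTAL_NUMS)
    return int(np.argmax(np.log(us) / p_v))


def window_seed(window):
    return int(hashlib.md5(str(window).encode()).hexdigest(), 16)


def rank_q(seed, index):
    # Fraction of the regenerated us values strictly larger than us[index].
    us = np.random.default_rng(seed).random(TOTAL_NUMS)
    return float(np.mean(us > us[index]))


sampler = np.random.default_rng(7)
true_qs, noise_qs = [], []
for _ in range(50):
    window = [int(x) for x in sampler.integers(0, TOTAL_NUMS, size=5)]
    seed = window_seed(window)
    true_qs.append(rank_q(seed, choose_one(seed)))       # value drawn from this window's us
    noise_qs.append(rank_q(seed, choose_one(seed + 1)))  # value drawn from an unrelated us

print("median q, genuine:", np.median(true_qs))
print("median q, noise:  ", np.median(noise_qs))
```

One would expect the first median to be on the order of 10^-3 and the second to sit near the middle of [0, 1].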
## The Solution
We can build a distinguisher based on this q value.
- **Calculate `q` for every position.** For each index `j` from 5 to 139:
  - Calculate the seed: `md5(str(v[j-5:j]).encode())`.
  - Generate `us` using this seed.
  - Compute `q = (count of us > us[v[j]]) / total_count`.
- **Analyze the distribution of `q`.**
  - Bit 1 (Random): The `q` values will look like a uniform distribution.
  - Bit 0 (Structured): Many `q` values will be very small (e.g., < 0.001), corresponding to the positions where neither `v[j]` nor any of the five elements in its seed window was replaced by noise.
- **Classification.** We can use a simple heuristic or a trained classifier.
  - Heuristic: If we see a "run" of very small `q` values (e.g., 2 consecutive `q < 0.001`), it's almost certainly Bit 0. For a random vector, each adjacent pair of `q` values hits this with probability about $0.001^2 = 10^{-6}$, so across the 134 adjacent pairs the chance is roughly $1.3 \times 10^{-4}$ per vector.
  - Classifier: The provided solver uses a Logistic Regression model trained on features like "mean of the smallest 5 `q`s", "count of `q < 0.001`", etc.
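The whole pipeline above can be checked end to end with a local simulation. This sketch copies the server's `get_vector()` logic into a hypothetical `simulate()` helper whose noise level is a parameter (the server draws it from 0 to 64), takes `p_v` from a fixed-seed stand-in for the unknown global state, and prints, for one vector of each type, how many `q` values fall below 0.001 and the longest run of them.

```python
import hashlib
import random

import numpy as np

WINDOW, TOTAL_NUMS, VECTOR_SIZE = 5, 20000, 140
# Hypothetical stand-in for the server's unknown global NumPy state (used only for p_v).
weights_rng = np.random.default_rng(0)


def choose_one(seed=None):
    p_v = 10 ** weights_rng.uniform(0, 13, size=TOTAL_NUMS)
    us = np.random.default_rng(seed).random(TOTAL_NUMS)  # seed=None -> fresh OS entropy
    return int(np.argmax(np.log(us) / p_v))


def window_seed(window):
    return int(hashlib.md5(str(window).encode()).hexdigest(), 16)


def simulate(bit, to_change, picker):
    """Local copy of get_vector() with the amount of noise chosen by the caller."""
    if bit == 1:
        return [choose_one() for _ in range(VECTOR_SIZE)]
    v = []
    for _ in range(VECTOR_SIZE):
        seed = window_seed(v[-WINDOW:]) if len(v) >= WINDOW else None
        v.append(choose_one(seed))
    for p in picker.choices(range(VECTOR_SIZE), k=to_change):
        v[p] = choose_one()
    return v


def q_values(v):
    qs = []
    for j in range(WINDOW, len(v)):
        us = np.random.default_rng(window_seed(v[j - WINDOW:j])).random(TOTAL_NUMS)
        qs.append(float(np.mean(us > us[v[j]])))
    return np.array(qs)


def longest_run(qs, threshold=0.001):
    best = cur = 0
    for q in qs:
        cur = cur + 1 if q < threshold else 0
        best = max(best, cur)
    return best


picker = random.Random(0)
small_counts = {}
for bit in (0, 1):
    qs = q_values(simulate(bit, to_change=32, picker=picker))
    small_counts[bit] = int(np.sum(qs < 0.001))
    print(f"bit {bit}: {small_counts[bit]} of {len(qs)} q values < 0.001, "
          f"longest run {longest_run(qs)}")
```

Pushing `to_change` toward 64 shrinks the number of intact windows, which is likely the regime where the run heuristic alone is least reliable and the extra classifier features earn their keep.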
## Implementation Details
The solver script `solve.py` implements this:
- Solves the PoW.
- For each round:
  - Computes the `q` values for the vector.
  - Extracts features (statistics of the sorted `q` values).
  - Checks for a "run" of small `q`s (strong signal).
  - If there is no run, uses the logistic regression score.
  - Sends the guess.
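The run check is treated as decisive because it almost never fires on a random vector. Assuming the 135 `q` values of a Bit 1 vector are independent and uniform, the probability of at least one run of length `r` below a threshold `p` can be computed exactly with a small dynamic program over the current run length. `prob_run_at_least` is a hypothetical helper, not part of `solve.py`.

```python
def prob_run_at_least(n, p, r):
    """P(a run of >= r consecutive successes) in n independent Bernoulli(p) trials."""
    # state[k] = probability that no run of length r has occurred yet and the
    # current trailing run of successes has length k (0 <= k < r).
    state = [1.0] + [0.0] * (r - 1)
    hit = 0.0
    for _ in range(n):
        new = [0.0] * r
        for k, mass in enumerate(state):
            new[0] += mass * (1 - p)      # a failure resets the run
            if k + 1 == r:
                hit += mass * p           # this success completes a run of length r
            else:
                new[k + 1] += mass * p    # the run grows by one
        state = new
    return hit


per_vector = prob_run_at_least(n=135, p=0.001, r=2)
print(f"false-positive rate per random vector: {per_vector:.3e}")
# About half of the 200 rounds are Bit 1, so a rough per-game risk is:
print(f"approx. per-game risk: {1 - (1 - per_vector) ** 100:.3e}")
```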
Flag: 0ops{34sy_st4tistics_g@me_Thou_@rt_more_lovely_and_more_temperate}
```python
import ast
import hashlib
import math
import socket
import time
import sys

import numpy as np

if len(sys.argv) > 2:
    HOST = sys.argv[1]
    PORT = int(sys.argv[2])
else:
    HOST = "127.0.0.1"
    PORT = 1337

WINDOW = 5
TOTAL_NUMS = 20000

# Logistic model trained on simulated vectors (features from sorted q values).
WEIGHTS = [
    -1.67498857e-04,
    2.17279932e-03,
    -7.82586101e-03,
    -1.42188400e-02,
    -2.13180118e-02,
    -1.69755344e-02,
    -1.38820575e-02,
    -4.97286329e-01,
    -5.37879989e-01,
    -1.71103834e00,
    -4.49118522e-01,
]
INTERCEPT = 7.45160422
SCORE_THRESHOLD = 0.6
RUN_THRESHOLD = 0.001
RUN_LENGTH = 2


def solve_pow(challenge: str) -> str:
    prefix = challenge.encode()
    target = b"\x00\x00\x00"  # difficulty=6 hex zeros
    i = 0
    while True:
        s = str(i).encode()
        if hashlib.sha256(prefix + s).digest().startswith(target):
            return s.decode()
        i += 1


def vector_features(vector):
    qs = []
    for j in range(WINDOW, len(vector)):
        # Rebuild the seeded us array exactly as the server would for position j.
        seed = hashlib.md5(str(vector[j - WINDOW : j]).encode()).hexdigest()
        rng = np.random.default_rng(int(seed, 16))
        us = rng.random(TOTAL_NUMS)
        u_obs = us[vector[j]]
        q = float(np.mean(us >= u_obs))
        qs.append(q)
    qs = np.array(qs)
    qs_sorted = np.sort(qs)
    feats = [
        float(qs_sorted[0]),
        float(qs_sorted[:2].mean()),
        float(qs_sorted[:3].mean()),
        float(qs_sorted[:5].mean()),
        float(qs_sorted[:10].mean()),
        float(qs_sorted[:20].mean()),
        float(qs_sorted.mean()),
        int(np.sum(qs_sorted < 0.0005)),
        int(np.sum(qs_sorted < 0.001)),
        int(np.sum(qs_sorted < 0.002)),
        int(np.sum(qs_sorted < 0.005)),
    ]
    return qs, feats


def max_run_below(qs, threshold):
    max_run = 0
    cur = 0
    for q in qs:
        if q < threshold:
            cur += 1
            if cur > max_run:
                max_run = cur
        else:
            cur = 0
    return max_run


def classify_vector(vector):
    qs, feats = vector_features(vector)
    z = INTERCEPT + sum(w * x for w, x in zip(WEIGHTS, feats))
    score = 1.0 / (1.0 + math.exp(-z))  # high score means "random" (bit 1)
    run_len = max_run_below(qs, RUN_THRESHOLD)
    # Strong indicator: at least two consecutive tiny-q values -> bit0.
    if run_len >= RUN_LENGTH:
        return 0
    return 0 if score < SCORE_THRESHOLD else 1


def extract_challenge(line):
    marker = "sha256("
    start = line.find(marker)
    if start == -1:
        return None
    rest = line[start + len(marker) :]
    sep = " + ???"
    end = rest.find(sep)
    if end == -1:
        return None
    return rest[:end].strip()


def solve_once():
    try:
        with socket.create_connection((HOST, PORT), timeout=10) as sock:
            sock.settimeout(15)
            f = sock.makefile("rwb", buffering=0)
            while True:
                data = f.readline()
                if not data:
                    return False
                line = data.decode(errors="ignore").strip()
                if not line:
                    continue
                challenge = extract_challenge(line)
                if challenge:
                    print("Solving PoW...", flush=True)
                    answer = solve_pow(challenge)
                    f.write((answer + "\n").encode())
                    continue
                if line.startswith("Vector:"):
                    vec_str = line.split("Vector:", 1)[1].strip()
                    vector = ast.literal_eval(vec_str)
                    bit = classify_vector(vector)
                    f.write((str(bit) + "\n").encode())
                    continue
                print(line)
                if "Wrong answer" in line:
                    return False
                if "flag is" in line:
                    return True
    except (OSError, socket.timeout) as exc:
        print(f"Connection error: {exc}", flush=True)
    return False


def main():
    for attempt in range(1, 11):
        print(f"Attempt {attempt}/10", flush=True)
        ok = solve_once()
        if ok:
            return
        print("Retrying...", flush=True)
        time.sleep(1)


if __name__ == "__main__":
    main()
```
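`solve.py` hard-codes `WEIGHTS` and `INTERCEPT` but does not include the training step. One way to obtain such numbers, assuming feature rows produced by `vector_features()` on simulated vectors labelled 1 for random and 0 for structured, is plain batch gradient descent on the logistic loss. `fit_logistic` is a hypothetical helper; this is not necessarily how the original weights were produced.

```python
import numpy as np


def fit_logistic(X, y, lr=0.1, steps=2000):
    """Batch gradient descent on mean log-loss; y = 1 marks a random (Bit 1) vector."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # Standardise features so one learning rate suits counts and tiny q means alike.
    mean = X.mean(axis=0)
    std = X.std(axis=0) + 1e-12
    Xs = (X - mean) / std
    w = np.zeros(Xs.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(Xs @ w + b)))
        err = p - y
        w -= lr * (Xs.T @ err) / len(y)
        b -= lr * err.mean()
    # Fold the standardisation back in so the weights apply to raw feature values.
    raw_w = w / std
    raw_b = b - float(np.sum(w * mean / std))
    return raw_w, raw_b


# Toy usage with two synthetic, well-separated feature clusters (not real q features).
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(1.0, 0.1, (50, 2))])
y = np.array([0] * 50 + [1] * 50)
weights, intercept = fit_logistic(X, y)
accuracy = np.mean(((X @ weights + intercept) > 0) == y)
print(weights, intercept, accuracy)
```

Because the standardisation is folded back in, the returned pair plugs directly into the same `z = INTERCEPT + sum(w * x ...)` formula that `classify_vector()` uses.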