|
| 1 | +import math |
| 2 | +import random |
| 3 | +import time |
| 4 | + |
| 5 | +from PIL import Image as PILImage |
| 6 | +from PIL import ImageDraw |
| 7 | + |
| 8 | +import trackio as wandb |
| 9 | + |
| 10 | +EPOCHS = 12 |
| 11 | +W = H = 128 |
| 12 | + |
| 13 | + |
| 14 | +def lissajous(t, w=W, h=H): |
| 15 | + x = (w // 2) + int((w // 3) * math.sin(2.0 * t)) |
| 16 | + y = (h // 2) + int((h // 3) * math.sin(3.0 * t + math.pi / 4)) |
| 17 | + return x, y |
| 18 | + |
| 19 | + |
| 20 | +def render_overlay(target_xy, pred_xy): |
| 21 | + img = PILImage.new("RGB", (W, H), "black") |
| 22 | + draw = ImageDraw.Draw(img) |
| 23 | + |
| 24 | + tx, ty = target_xy |
| 25 | + draw.ellipse((tx - 5, ty - 5, tx + 5, ty + 5), fill=(0, 255, 0)) |
| 26 | + |
| 27 | + px, py = pred_xy |
| 28 | + draw.ellipse((px - 5, py - 5, px + 5, py + 5), fill=(255, 80, 80)) |
| 29 | + |
| 30 | + draw.line([(tx, ty), (px, py)], fill=(255, 255, 0), width=1) |
| 31 | + return img |
| 32 | + |
| 33 | + |
| 34 | +def main(): |
| 35 | + project_id = random.randint(10000, 99999) |
| 36 | + project_name = f"image-logging-demo-{project_id}" |
| 37 | + |
| 38 | + for run_index in range(1, 3): |
| 39 | + run_name = f"image-run-{run_index}" |
| 40 | + wandb.init(project=project_name, name=run_name) |
| 41 | + |
| 42 | + pred_x, pred_y = random.randint(0, W - 1), random.randint(0, H - 1) |
| 43 | + |
| 44 | + for epoch in range(EPOCHS): |
| 45 | + t = epoch / 3.0 |
| 46 | + target = lissajous(t) |
| 47 | + |
| 48 | + lr = 0.35 |
| 49 | + noise_scale = max(0.0, 5.0 * (1.0 - epoch / (EPOCHS - 1))) |
| 50 | + pred_x += lr * (target[0] - pred_x) + random.uniform( |
| 51 | + -noise_scale, noise_scale |
| 52 | + ) |
| 53 | + pred_y += lr * (target[1] - pred_y) + random.uniform( |
| 54 | + -noise_scale, noise_scale |
| 55 | + ) |
| 56 | + pred = (int(round(pred_x)), int(round(pred_y))) |
| 57 | + |
| 58 | + loss = math.dist(target, pred) |
| 59 | + |
| 60 | + overlay = render_overlay(target, pred) |
| 61 | + wandb.log( |
| 62 | + { |
| 63 | + "loss": loss, |
| 64 | + "target_x": target[0], |
| 65 | + "target_y": target[1], |
| 66 | + "pred_x": pred[0], |
| 67 | + "pred_y": pred[1], |
| 68 | + "overlay": wandb.Image( |
| 69 | + overlay, caption=f"step={epoch}, loss={loss:.2f}" |
| 70 | + ), |
| 71 | + }, |
| 72 | + step=epoch, |
| 73 | + ) |
| 74 | + time.sleep(0.2) |
| 75 | + |
| 76 | + wandb.finish() |
| 77 | + |
| 78 | + |
| 79 | +if __name__ == "__main__": |
| 80 | + main() |
0 commit comments