reworked results chpt
This commit is contained in:
@@ -143,6 +143,11 @@ def _dynamic_ylim(all_vals: List[float], all_errs: List[float]) -> Tuple[float,
|
||||
return (float(y0), float(y1))
|
||||
|
||||
|
||||
def _get_dim_mapping(dims: list[int]) -> dict[int, int]:
|
||||
"""Map actual dimensions to evenly spaced positions (0, 1, 2, ...)"""
|
||||
return {dim: i for i, dim in enumerate(dims)}
|
||||
|
||||
|
||||
def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: Path):
|
||||
fig, axes = plt.subplots(
|
||||
len(SEMI_LABELING_REGIMES),
|
||||
@@ -155,6 +160,9 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
|
||||
if len(SEMI_LABELING_REGIMES) == 1:
|
||||
axes = [axes]
|
||||
|
||||
# Create dimension mapping
|
||||
dim_mapping = _get_dim_mapping(LATENT_DIMS)
|
||||
|
||||
for ax, regime in zip(axes, SEMI_LABELING_REGIMES):
|
||||
semi_n, semi_a = regime
|
||||
data = {}
|
||||
@@ -163,7 +171,9 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
|
||||
for dim in LATENT_DIMS:
|
||||
key = (ev, net, dim, semi_n, semi_a)
|
||||
if key in agg:
|
||||
xs.append(dim)
|
||||
xs.append(
|
||||
dim_mapping[dim]
|
||||
) # Use mapped position instead of actual dim
|
||||
ys.append(agg[key].mean)
|
||||
es.append(agg[key].std)
|
||||
data[net] = (xs, ys, es)
|
||||
@@ -172,12 +182,16 @@ def plot_eval(ev: str, agg: Dict[Tuple[str, str, int, int, int], Agg], outdir: P
|
||||
xs, ys, es = data[net]
|
||||
if not xs:
|
||||
continue
|
||||
ax.set_xticks(LATENT_DIMS)
|
||||
ax.yaxis.set_major_locator(MaxNLocator(nbins=5)) # e.g., always 5 ticks
|
||||
|
||||
# Set evenly spaced ticks with actual dimension labels
|
||||
ax.set_xticks(list(dim_mapping.values()))
|
||||
ax.set_xticklabels(LATENT_DIMS)
|
||||
|
||||
ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
|
||||
ax.scatter(
|
||||
xs, ys, s=35, color=color, alpha=SCATTER_ALPHA, label=f"{net} (points)"
|
||||
)
|
||||
x_fit, y_fit = _lin_trend(xs, ys)
|
||||
x_fit, y_fit = _lin_trend(xs, ys) # Now using mapped positions
|
||||
ax.plot(
|
||||
x_fit,
|
||||
y_fit,
|
||||
|
||||
Reference in New Issue
Block a user