AudioCraft: use_sampling + top_p nach offizieller Meta Doku

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Andre 2026-03-21 15:38:03 +01:00
parent ee9636684d
commit f46de40be4

View file

@ -36,9 +36,11 @@ class MusicGenNode:
mg = MusicGen.get_pretrained(model)
mg.set_generation_params(
duration=duration,
use_sampling=True,
temperature=temperature,
cfg_coef=cfg_coef,
top_k=top_k,
top_p=0.0,
extend_stride=extend_stride,
)
@ -90,7 +92,8 @@ class MusicGenLongNode:
context_samples = int(context_seconds * sample_rate)
# Erstes Segment
mg.set_generation_params(duration=segment_duration, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
gen_params = dict(use_sampling=True, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k, top_p=0.0)
mg.set_generation_params(duration=segment_duration, **gen_params)
print(f"[MusicGenLong] Segment 1 / {int(total_duration / segment_duration) + 1}")
first = mg.generate([prompt])
segments = [first[0].cpu()]
@ -101,7 +104,7 @@ class MusicGenLongNode:
while generated < total_duration:
remaining = total_duration - generated
next_dur = min(segment_duration, remaining)
mg.set_generation_params(duration=next_dur, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
mg.set_generation_params(duration=next_dur, **gen_params)
# Letzten context_seconds des vorherigen Segments als Kontext
context = segments[-1][:, -context_samples:]