AudioCraft: use_sampling + top_p nach offizieller Meta Doku
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ee9636684d
commit
f46de40be4
1 changed files with 5 additions and 2 deletions
|
|
@ -36,9 +36,11 @@ class MusicGenNode:
|
|||
mg = MusicGen.get_pretrained(model)
|
||||
mg.set_generation_params(
|
||||
duration=duration,
|
||||
use_sampling=True,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k,
|
||||
top_p=0.0,
|
||||
extend_stride=extend_stride,
|
||||
)
|
||||
|
||||
|
|
@ -90,7 +92,8 @@ class MusicGenLongNode:
|
|||
context_samples = int(context_seconds * sample_rate)
|
||||
|
||||
# Erstes Segment
|
||||
mg.set_generation_params(duration=segment_duration, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
|
||||
gen_params = dict(use_sampling=True, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k, top_p=0.0)
|
||||
mg.set_generation_params(duration=segment_duration, **gen_params)
|
||||
print(f"[MusicGenLong] Segment 1 / {int(total_duration / segment_duration) + 1}")
|
||||
first = mg.generate([prompt])
|
||||
segments = [first[0].cpu()]
|
||||
|
|
@ -101,7 +104,7 @@ class MusicGenLongNode:
|
|||
while generated < total_duration:
|
||||
remaining = total_duration - generated
|
||||
next_dur = min(segment_duration, remaining)
|
||||
mg.set_generation_params(duration=next_dur, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
|
||||
mg.set_generation_params(duration=next_dur, **gen_params)
|
||||
|
||||
# Letzten context_seconds des vorherigen Segments als Kontext
|
||||
context = segments[-1][:, -context_samples:]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue