AudioCraft: use_sampling + top_p nach offizieller Meta Doku
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ee9636684d
commit
f46de40be4
1 changed files with 5 additions and 2 deletions
|
|
@ -36,9 +36,11 @@ class MusicGenNode:
|
||||||
mg = MusicGen.get_pretrained(model)
|
mg = MusicGen.get_pretrained(model)
|
||||||
mg.set_generation_params(
|
mg.set_generation_params(
|
||||||
duration=duration,
|
duration=duration,
|
||||||
|
use_sampling=True,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
cfg_coef=cfg_coef,
|
cfg_coef=cfg_coef,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
top_p=0.0,
|
||||||
extend_stride=extend_stride,
|
extend_stride=extend_stride,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -90,7 +92,8 @@ class MusicGenLongNode:
|
||||||
context_samples = int(context_seconds * sample_rate)
|
context_samples = int(context_seconds * sample_rate)
|
||||||
|
|
||||||
# Erstes Segment
|
# Erstes Segment
|
||||||
mg.set_generation_params(duration=segment_duration, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
|
gen_params = dict(use_sampling=True, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k, top_p=0.0)
|
||||||
|
mg.set_generation_params(duration=segment_duration, **gen_params)
|
||||||
print(f"[MusicGenLong] Segment 1 / {int(total_duration / segment_duration) + 1}")
|
print(f"[MusicGenLong] Segment 1 / {int(total_duration / segment_duration) + 1}")
|
||||||
first = mg.generate([prompt])
|
first = mg.generate([prompt])
|
||||||
segments = [first[0].cpu()]
|
segments = [first[0].cpu()]
|
||||||
|
|
@ -101,7 +104,7 @@ class MusicGenLongNode:
|
||||||
while generated < total_duration:
|
while generated < total_duration:
|
||||||
remaining = total_duration - generated
|
remaining = total_duration - generated
|
||||||
next_dur = min(segment_duration, remaining)
|
next_dur = min(segment_duration, remaining)
|
||||||
mg.set_generation_params(duration=next_dur, temperature=temperature, cfg_coef=cfg_coef, top_k=top_k)
|
mg.set_generation_params(duration=next_dur, **gen_params)
|
||||||
|
|
||||||
# Letzten context_seconds des vorherigen Segments als Kontext
|
# Letzten context_seconds des vorherigen Segments als Kontext
|
||||||
context = segments[-1][:, -context_samples:]
|
context = segments[-1][:, -context_samples:]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue