Skip to content

Commit

Permalink
Add prompting support (#25)
Browse files Browse the repository at this point in the history
* Add prompting support #8 (upd to 1.3.0)

* fix C# 9 target-typed new (for Unity 2019/2020)

* fixes (#25)
  • Loading branch information
SharafeevRavil committed Jun 17, 2023
1 parent 8532481 commit 269eea8
Show file tree
Hide file tree
Showing 8 changed files with 2,000 additions and 207 deletions.
2,017 changes: 1,814 additions & 203 deletions Assets/Samples/1 - Audio Clip/1 - Audio Clip.unity

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions Assets/Samples/1 - Audio Clip/AudioClipDemo.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,51 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using UnityEngine;
using UnityEngine.UI;
// ReSharper disable ArrangeObjectCreationWhenTypeEvident - for Unity 2019/2020 support:

namespace Whisper.Samples
{
public class AudioClipDemo : MonoBehaviour
{
/// <summary>
/// A named initial prompt preset: <see cref="name"/> is used as the dropdown option label,
/// <see cref="prompt"/> is the text copied into the prompt input field when the preset is selected.
/// </summary>
[Serializable]
public class InitialPrompt
{
// label of this preset in the prompt dropdown
public string name;
// prompt text placed into the input field when this preset is selected
public string prompt;
}

public WhisperManager manager;
public AudioClip clip;
public bool echoSound = true;

// Prompt presets offered by the demo. Awake builds the dropdown options from the preset names,
// and selecting an option copies the preset's prompt into the input field.
// The last default preset ("Custom") has an empty prompt.
public List<InitialPrompt> initialPrompts = new List<InitialPrompt>
{
new InitialPrompt
{
name = "lowercase",
prompt = "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they",
},
new InitialPrompt
{
name = "Start of the clip",
prompt = "And so my fellow Americans, ask not what your country can do for you",
},
new InitialPrompt
{
name = "UPPERCASE",
prompt = "HELLO HOW IS IT GOING ALWAYS USE UPPERCASE NO PUNCTUATION GOODBYE ONE TWO THREE START STOP I YOU ME THEY",
},
new InitialPrompt
{
name = "Custom",
prompt = "",
},
};

[Header("Text Output")]
public bool streamSegments = true;
public bool printLanguage = true;
Expand All @@ -21,6 +55,8 @@ public class AudioClipDemo : MonoBehaviour
public Button button;
public Text outputText;
public Text timeText;
public Dropdown initialPromptDropdown;
public InputField selectedInitialPromptInput;

private string _buffer;

Expand All @@ -29,6 +65,13 @@ private void Awake()
button.onClick.AddListener(ButtonPressed);
if (streamSegments)
manager.OnNewSegment += OnNewSegmentHandler;

initialPromptDropdown.options = initialPrompts
.Select(x => new Dropdown.OptionData(x.name))
.ToList();
initialPromptDropdown.onValueChanged.AddListener(OnInitialPromptChanged);
initialPromptDropdown.value = 0;
OnInitialPromptChanged(initialPromptDropdown.value);
}

private void OnDestroy()
Expand All @@ -37,12 +80,17 @@ private void OnDestroy()
manager.OnNewSegment -= OnNewSegmentHandler;
}

/// <summary>
/// Dropdown callback: copies the prompt text of the selected preset into the prompt input field.
/// </summary>
/// <param name="ind">Index of the selected dropdown option; expected to index into <see cref="initialPrompts"/>.</param>
private void OnInitialPromptChanged(int ind)
{
    // Awake calls this with the dropdown's current value even when the presets list is empty,
    // so ignore indices that have no matching preset instead of throwing from Awake.
    if (initialPrompts == null || ind < 0 || ind >= initialPrompts.Count)
        return;

    selectedInitialPromptInput.text = initialPrompts[ind].prompt;
}

public async void ButtonPressed()
{
_buffer = "";
if (echoSound)
AudioSource.PlayClipAtPoint(clip, Vector3.zero);

// set initial prompt in manager
manager.initialPrompt = selectedInitialPromptInput.text;

var sw = new Stopwatch();
sw.Start();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public unsafe struct WhisperNativeParams

// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
byte* initial_prompt;
public byte* initial_prompt;
whisper_token_ptr prompt_tokens;
int prompt_n_tokens;

Expand Down Expand Up @@ -137,7 +137,7 @@ struct beam_search_struct
// called for every newly generated text segment
public whisper_new_segment_callback new_segment_callback;
public System.IntPtr new_segment_callback_user_data;

// called on each progress update
void* progress_callback;
void* progress_callback_user_data;
Expand Down
7 changes: 7 additions & 0 deletions Packages/com.whisper.unity/Runtime/WhisperManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ public class WhisperManager : MonoBehaviour

[Tooltip("Output tokens with their confidence in each segment.")]
public bool enableTokens;

[Tooltip("Initial prompt as a string variable. " +
"It should improve transcription quality or guide it to the right direction.")]
[TextArea]
public string initialPrompt;


[Header("Experimental settings")]
[Tooltip("[EXPERIMENTAL] Output timestamps for each token. Need enabled tokens to work.")]
Expand Down Expand Up @@ -152,6 +158,7 @@ private void UpdateParams()
_params.AudioCtx = audioCtx;
_params.EnableTokens = enableTokens;
_params.TokenTimestamps = tokensTimestamps;
_params.InitialPrompt = initialPrompt;
}

private async Task<bool> CheckIfLoaded()
Expand Down
47 changes: 45 additions & 2 deletions Packages/com.whisper.unity/Runtime/WhisperParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ public class WhisperParams
private WhisperNativeParams _param;
private string _languageManaged;
private IntPtr _languagePtr = IntPtr.Zero;
private string _initialPromptManaged;
private IntPtr _initialPromptPtr = IntPtr.Zero;

/// <summary>
/// Native C++ struct parameters.
Expand All @@ -26,8 +28,12 @@ private unsafe WhisperParams(WhisperNativeParams param)
_param = param;

// copy language string to managed memory
var strPtr = new IntPtr(param.language);
_languageManaged = Marshal.PtrToStringAnsi(strPtr);
var languageStrPtr = new IntPtr(param.language);
_languageManaged = Marshal.PtrToStringAnsi(languageStrPtr);

// copy initial_prompt string to managed memory
var initialPromptStrPtr = new IntPtr(param.initial_prompt);
_initialPromptManaged = Marshal.PtrToStringAnsi(initialPromptStrPtr);

// reset callbacks
_param.new_segment_callback = null;
Expand All @@ -37,6 +43,7 @@ private unsafe WhisperParams(WhisperNativeParams param)
// Finalizer: releases the string copies owned by this instance through the two Free* helpers
// (FreeInitialPromptString frees the HGlobal buffer allocated by the InitialPrompt setter;
// presumably FreeLanguageString does the same for the language string - confirm).
~WhisperParams()
{
FreeLanguageString();
FreeInitialPromptString();
}

#region Basic Parameters
Expand Down Expand Up @@ -196,6 +203,32 @@ public string Language
}
}
}

/// <summary>
/// Initial prompt text. It is converted to tokens and prepended to any existing
/// text context from a previous call.
/// <a href="https://github.com/ggerganov/whisper.cpp/discussions/348#discussioncomment-4559682">Usage example</a>
/// </summary>
/// <remarks>
/// Setting a non-null value copies it into unmanaged memory as a null-terminated UTF-8 string
/// owned by this instance. Setting <c>null</c> frees that copy and clears the native prompt pointer.
/// </remarks>
public string InitialPrompt
{
    get => _initialPromptManaged;
    set
    {
        if (_initialPromptManaged == value)
            return;

        _initialPromptManaged = value;
        unsafe
        {
            // free previous string and clear the native pointer right away, so native params
            // never keep pointing at freed memory (a null value used to leave it dangling)
            FreeInitialPromptString();
            _param.initial_prompt = null;

            if (_initialPromptManaged == null)
                return;

            // copy the string to unmanaged memory so the native side gets a stable pointer.
            // Encode explicitly as UTF-8: StringToHGlobalAnsi uses the system ANSI code page
            // on Windows, which garbles non-ASCII prompts. ASCII prompts produce the same bytes.
            // NOTE(review): assumes whisper.cpp reads the prompt as UTF-8 - confirm.
            var utf8 = System.Text.Encoding.UTF8.GetBytes(_initialPromptManaged);
            _initialPromptPtr = Marshal.AllocHGlobal(utf8.Length + 1);
            Marshal.Copy(utf8, 0, _initialPromptPtr, utf8.Length);
            Marshal.WriteByte(_initialPromptPtr, utf8.Length, 0); // null terminator
            _param.initial_prompt = (byte*)_initialPromptPtr;
        }
    }
}

#endregion

Expand Down Expand Up @@ -281,6 +314,16 @@ private void FreeLanguageString()
_languagePtr = IntPtr.Zero;
}

/// <summary>
/// Releases the unmanaged initial prompt copy that was allocated from C#, if there is one.
/// </summary>
private void FreeInitialPromptString()
{
    // A zero pointer means C# never allocated a prompt string, so there is nothing to free
    // (C++ literals are never freed from here).
    // This relies on whisper not changing the initial prompt string on the C++ side.
    if (_initialPromptPtr == IntPtr.Zero)
        return;

    Marshal.FreeHGlobal(_initialPromptPtr);
    _initialPromptPtr = IntPtr.Zero;
}

public static WhisperParams GetDefaultParams(WhisperSamplingStrategy strategy =
WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY)
{
Expand Down
21 changes: 21 additions & 0 deletions Packages/com.whisper.unity/Tests/Runtime/WhisperParamsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,26 @@ public void LanguageParamsTest()
param.Language = "de";
Assert.AreEqual("de", param.Language);
}

// Checks that the initial prompt can be read by default and round-trips
// through the InitialPrompt property for empty, null and regular values.
[Test]
public void PromptParamsTest()
{
    var whisperParams = WhisperParams.GetDefaultParams();
    Assert.NotNull(whisperParams);

    // reading the default prompt must not throw
    Assert.DoesNotThrow(() => { var unused = whisperParams.InitialPrompt; });

    // "no prompt" cases: empty string and null
    whisperParams.InitialPrompt = string.Empty;
    Assert.AreEqual(string.Empty, whisperParams.InitialPrompt);

    whisperParams.InitialPrompt = null;
    Assert.IsNull(whisperParams.InitialPrompt);

    // a regular prompt must be returned unchanged
    const string expectedPrompt = "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they";
    whisperParams.InitialPrompt = expectedPrompt;
    Assert.AreEqual(expectedPrompt, whisperParams.InitialPrompt);
}
}
}
60 changes: 60 additions & 0 deletions Packages/com.whisper.unity/Tests/Runtime/WhisperPromptTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using System.IO;
using NUnit.Framework;
using UnityEngine;

namespace Whisper.Tests
{
    /// <summary>
    /// Runs whisper inference with different values of <c>WhisperParams.InitialPrompt</c>.
    /// </summary>
    [TestFixture]
    public class WhisperPromptTests
    {
        private readonly string _modelPath = Path.Combine(Application.streamingAssetsPath, "Whisper/ggml-tiny.bin");

        // zero-filled sample buffer used as the audio input
        private readonly float[] _buffer = new float[32000];
        private const int Frequency = 8000;
        private const int Channels = 2;

        private WhisperWrapper _whisper;
        private WhisperParams _params;

        /// <summary>
        /// Loads the model and creates fresh default params before each test,
        /// so a prompt set by one test does not carry over to another.
        /// </summary>
        [SetUp]
        public void Setup()
        {
            _whisper = WhisperWrapper.InitFromFile(_modelPath);
            _params = WhisperParams.GetDefaultParams();
        }

        /// <summary>
        /// Inference must return a result when the prompt is explicitly set to null.
        /// </summary>
        [Test]
        public void RunWithNullPrompt()
        {
            _params.InitialPrompt = null;

            var res = _whisper.GetText(_buffer, Frequency, Channels, _params);
            Assert.NotNull(res);
        }

        /// <summary>
        /// Inference must return a result when the prompt is an empty string.
        /// </summary>
        [Test]
        public void RunWithEmptyPrompt()
        {
            _params.InitialPrompt = "";

            var res = _whisper.GetText(_buffer, Frequency, Channels, _params);
            Assert.NotNull(res);
        }

        /// <summary>
        /// The same clip transcribed with the default prompt and with a long custom prompt
        /// is expected to produce different text.
        /// </summary>
        [Test]
        public void TextDiffersDueToPrompt()
        {
            var clip = AudioClip.Create("test", _buffer.Length, Channels, Frequency, false);
            try
            {
                var res1 = _whisper.GetText(clip, _params);
                Assert.NotNull(res1);

                _params.InitialPrompt = "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they" +
                                        " EVERY WORD IS WRITTEN IN CAPITAL LETTERS AS IF THE CAPS LOCK KEY WAS PRESSED" +
                                        ". This long prompt should change the result a lot, i hope!";
                var res2 = _whisper.GetText(clip, _params);
                Assert.NotNull(res2);

                // AreNotEqual reports both values on failure, unlike Assert.True(a != b)
                Assert.AreNotEqual(res1.Result, res2.Result);
            }
            finally
            {
                // the clip is created from code, so destroy it explicitly once the test is done,
                // also when an assertion above fails
                UnityEngine.Object.DestroyImmediate(clip);
            }
        }
    }
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 269eea8

Please sign in to comment.