Skip to content

Commit

Permalink
Merge pull request #6 from gsilvamartin/feature/streamenhance
Browse files Browse the repository at this point in the history
Add callback for stream
  • Loading branch information
gsilvamartin authored Feb 28, 2024
2 parents 687f836 + dc519b7 commit f18e9d5
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 30 deletions.
6 changes: 3 additions & 3 deletions src/DotnetGeminiSDK/Client/GeminiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public GeminiClient(GoogleGeminiConfig config, IApiRequester apiRequester)
/// <exception cref="Exception"></exception>
public Task StreamTextPrompt(
string message,
Action<GeminiMessageResponse> callback,
Action<string> callback,
GenerationConfig? generationConfig = null,
SafetySetting? safetySetting = null)
{
Expand All @@ -118,7 +118,7 @@ public Task StreamTextPrompt(
var promptUrl = $"{_config.TextBaseUrl}:streamGenerateContent?key={_config.ApiKey}";
var request = BuildGeminiRequest(message, generationConfig, safetySetting);

return _apiRequester.PostStream<GeminiMessageResponse>(promptUrl, request, callback);
return _apiRequester.PostStream(promptUrl, request, callback);
}
catch (Exception e)
{
Expand All @@ -143,7 +143,7 @@ public Task StreamTextPrompt(
/// <exception cref="Exception"></exception>
public Task StreamTextPrompt(
List<Content> messages,
Action<GeminiMessageResponse> callback,
Action<string> callback,
GenerationConfig? generationConfig = null,
SafetySetting? safetySetting = null)
{
Expand Down
4 changes: 2 additions & 2 deletions src/DotnetGeminiSDK/Client/Interfaces/IGeminiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ public interface IGeminiClient

Task StreamTextPrompt(
string message,
Action<GeminiMessageResponse> callback,
Action<string?> callback,
GenerationConfig? generationConfig = null,
SafetySetting? safetySetting = null
);

Task StreamTextPrompt(
List<Content> messages,
Action<GeminiMessageResponse> callback,
Action<string?> callback,
GenerationConfig? generationConfig = null,
SafetySetting? safetySetting = null
);
Expand Down
1 change: 0 additions & 1 deletion src/DotnetGeminiSDK/DotnetGeminiSDK.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="8.0.0"/>
<PackageReference Include="Newtonsoft.Json" Version="13.0.3"/>
<PackageReference Include="System.Reactive" Version="6.0.0"/>
</ItemGroup>
</Project>
31 changes: 8 additions & 23 deletions src/DotnetGeminiSDK/Requester/ApiRequester.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.IO;
using System.Net.Http;
using System.Reactive.Linq;
using System.Text;
using System.Threading.Tasks;
using DotnetGeminiSDK.Requester.Interfaces;
Expand Down Expand Up @@ -53,22 +51,23 @@ public async Task<T> PostAsync<T>(string url, object data)
/// <param name="url">Url to be requested</param>
/// <param name="data">Data containing body to send</param>
/// <param name="callback"> A callback to be called when the response is received</param>
/// <typeparam name="T">Return type of method</typeparam>
/// <returns>Observable post stream result</returns>
public async Task PostStream<T>(string url, object data, Action<T> callback)
public async Task PostStream(string url, object data, Action<string> callback)
{
var content = new StringContent(JsonConvert.SerializeObject(data), Encoding.UTF8, "application/json");

using var response = await _httpClient.PostAsync(url, content);
using var request = new HttpRequestMessage(HttpMethod.Post, url) { Content = content };
using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);

if (response.IsSuccessStatusCode)
{
int bytesRead;
var buffer = new byte[8192];
using var responseStream = await response.Content.ReadAsStreamAsync();
using var reader = new StreamReader(responseStream);

while (!reader.EndOfStream)
while ((bytesRead = await responseStream.ReadAsync(buffer, 0, buffer.Length)) > 0)
{
callback(await HandleResponse<T>(reader));
var chunk = Encoding.UTF8.GetString(buffer, 0, bytesRead);
callback(chunk);
}
}
else
Expand Down Expand Up @@ -122,19 +121,5 @@ private static async Task<T> HandleResponse<T>(HttpResponseMessage response)
return JsonConvert.DeserializeObject<T>(content) ??
throw new Exception("Cannot deserialize response from API");
}

/// <summary>
/// Handle the response from the API, deserializing the content from a stream
/// </summary>
/// <param name="streamReader">The stream reader</param>
/// <typeparam name="T">Return type of method</typeparam>
/// <returns></returns>
private static async Task<T> HandleResponse<T>(TextReader streamReader)
{
var content = await streamReader.ReadLineAsync();

return JsonConvert.DeserializeObject<T>(content) ??
throw new Exception("Cannot deserialize response from API");
}
}
}
2 changes: 1 addition & 1 deletion src/DotnetGeminiSDK/Requester/Interfaces/IApiRequester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public interface IApiRequester
{
Task<T> GetAsync<T>(string url);
Task<T> PostAsync<T>(string url, object data);
Task PostStream<T>(string url, object data, Action<T> callback);
Task PostStream(string url, object data, Action<string> callback);
Task<T> PutAsync<T>(string url, object data);
Task<T> DeleteAsync<T>(string url);
}
Expand Down

0 comments on commit f18e9d5

Please sign in to comment.