fix(ui/security): Enforce idempotent AI fetching, secure auth handler, and memory leak guards #45

Merged
mjasin merged 4 commits from fix/idempotent-ai-fetching into develop 2026-05-20 17:27:40 +00:00
6 changed files with 197 additions and 16 deletions
@@ -31,7 +31,7 @@ public class KnowledgeService : IKnowledgeService
private readonly Tokenizer _tokenizer; private readonly Tokenizer _tokenizer;
private readonly ILogger<KnowledgeService> _logger; private readonly ILogger<KnowledgeService> _logger;
private const string PromptVersion = "1.3"; private const string PromptVersion = "1.3";
private static readonly ConcurrentDictionary<string, Task<Result<KnowledgePacket>>> _activeRequests = new(); private static readonly ConcurrentDictionary<string, Lazy<Task<Result<KnowledgePacket>>>> _activeRequests = new();
public KnowledgeService( public KnowledgeService(
IChatClient chatClient, IChatClient chatClient,
@@ -100,10 +100,35 @@ public class KnowledgeService : IKnowledgeService
// Deduplicate concurrent active requests for the exact same hash // Deduplicate concurrent active requests for the exact same hash
var requestKey = $"{tenantId}:{hash}:{traceType}"; var requestKey = $"{tenantId}:{hash}:{traceType}";
var task = _activeRequests.GetOrAdd(requestKey, _ =>
ExecuteAiRequestAndCacheAsync(normalizedText, tenantId, systemPrompt, traceType, ebookId, hash));
return await task; var lazyTask = _activeRequests.GetOrAdd(requestKey, k =>
new Lazy<Task<Result<KnowledgePacket>>>(
() => ExecuteAiRequestAndCacheAsync(normalizedText, tenantId, systemPrompt, traceType, ebookId, hash),
System.Threading.LazyThreadSafetyMode.ExecutionAndPublication
));
try
{
var result = await lazyTask.Value;
// If the AI call returned a failure, remove it from the active dictionary
// so subsequent retries have a chance to request the AI again.
if (result.IsFailed)
{
_activeRequests.TryRemove(requestKey, out _);
}
return result;
}
catch (Exception)
{
_activeRequests.TryRemove(requestKey, out _);
throw;
}
finally
{
_activeRequests.TryRemove(requestKey, out _);
}
} }
private async Task<Result<KnowledgePacket>> ExecuteAiRequestAndCacheAsync( private async Task<Result<KnowledgePacket>> ExecuteAiRequestAndCacheAsync(
@@ -51,9 +51,13 @@
private string _lastFetchedBlockId = string.Empty; private string _lastFetchedBlockId = string.Empty;
private KnowledgePacket? _packet; private KnowledgePacket? _packet;
private CancellationTokenSource? _streamCts; private CancellationTokenSource? _streamCts;
private bool _isInteractive;
protected override async Task OnParametersSetAsync() protected override async Task OnParametersSetAsync()
{ {
if (!_isInteractive)
return;
// Only re-fetch when the block context actually changes // Only re-fetch when the block context actually changes
if (string.IsNullOrEmpty(ContextBlockId) || ContextBlockId == _lastFetchedBlockId) if (string.IsNullOrEmpty(ContextBlockId) || ContextBlockId == _lastFetchedBlockId)
return; return;
@@ -62,6 +66,19 @@
await FetchAndStreamAsync(); await FetchAndStreamAsync();
} }
protected override async Task OnAfterRenderAsync(bool firstRender)
{
if (firstRender)
{
_isInteractive = true;
if (!string.IsNullOrEmpty(ContextBlockId))
{
_lastFetchedBlockId = ContextBlockId;
await FetchAndStreamAsync();
}
}
}
private async Task FetchAndStreamAsync() private async Task FetchAndStreamAsync()
{ {
// Cancel any in-progress stream // Cancel any in-progress stream
@@ -26,15 +26,40 @@
private GroundednessResult? _result; private GroundednessResult? _result;
private bool _isChecking; private bool _isChecking;
private bool _isInteractive;
private string _previousAnswer = string.Empty;
private string _previousContext = string.Empty;
protected override void OnParametersSet()
{
if (Answer != _previousAnswer || Context != _previousContext)
{
_result = null;
_previousAnswer = Answer;
_previousContext = Context;
}
}
protected override async Task OnParametersSetAsync() protected override async Task OnParametersSetAsync()
{ {
if (!string.IsNullOrEmpty(Answer) && !string.IsNullOrEmpty(Context) && _result == null) if (_isInteractive && !string.IsNullOrEmpty(Answer) && !string.IsNullOrEmpty(Context) && _result == null)
{ {
await RunCheck(); await RunCheck();
} }
} }
protected override async Task OnAfterRenderAsync(bool firstRender)
{
if (firstRender)
{
_isInteractive = true;
if (!string.IsNullOrEmpty(Answer) && !string.IsNullOrEmpty(Context) && _result == null)
{
await RunCheck();
}
}
}
private async Task RunCheck() private async Task RunCheck()
{ {
_isChecking = true; _isChecking = true;
+3 -2
View File
@@ -27,11 +27,10 @@
private IJSObjectReference? _keydownHandler; private IJSObjectReference? _keydownHandler;
private DotNetObjectReference<Home>? _dotNetRef; private DotNetObjectReference<Home>? _dotNetRef;
protected override async Task OnInitializedAsync() protected override void OnInitialized()
{ {
QuizState.OnQuizRequested += HandleQuizRequestedAsync; QuizState.OnQuizRequested += HandleQuizRequestedAsync;
FocusMode.OnFocusModeChanged += HandleUpdate; FocusMode.OnFocusModeChanged += HandleUpdate;
await FocusMode.InitializeAsync();
} }
protected override async Task OnParametersSetAsync() protected override async Task OnParametersSetAsync()
@@ -65,11 +64,13 @@
{ {
if (firstRender) if (firstRender)
{ {
await FocusMode.InitializeAsync();
try { try {
_interopModule = await JS.InvokeAsync<IJSObjectReference>("import", "./_content/NexusReader.UI.Shared/js/focusInterop.js"); _interopModule = await JS.InvokeAsync<IJSObjectReference>("import", "./_content/NexusReader.UI.Shared/js/focusInterop.js");
_dotNetRef = DotNetObjectReference.Create(this); _dotNetRef = DotNetObjectReference.Create(this);
_keydownHandler = await _interopModule.InvokeAsync<IJSObjectReference>("attachKeyboardListener", _dotNetRef); _keydownHandler = await _interopModule.InvokeAsync<IJSObjectReference>("attachKeyboardListener", _dotNetRef);
} catch { } /* ignored dynamically */ } catch { } /* ignored dynamically */
StateHasChanged();
} }
} }
@@ -508,9 +508,12 @@
private bool _isLoading = true; private bool _isLoading = true;
private List<LastReadBookDto>? _books; private List<LastReadBookDto>? _books;
protected override async Task OnInitializedAsync() protected override async Task OnAfterRenderAsync(bool firstRender)
{ {
await LoadBooksAsync(); if (firstRender)
{
await LoadBooksAsync();
}
} }
private async Task LoadBooksAsync() private async Task LoadBooksAsync()
@@ -1,31 +1,141 @@
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Threading;
using Microsoft.AspNetCore.Components;
using Microsoft.AspNetCore.Components.WebAssembly.Http; using Microsoft.AspNetCore.Components.WebAssembly.Http;
using Microsoft.Extensions.DependencyInjection;
using NexusReader.Application.Abstractions.Services; using NexusReader.Application.Abstractions.Services;
namespace NexusReader.Web.Client.Handlers; namespace NexusReader.Web.Client.Handlers;
/// <summary>
/// A secure HTTP message delegating handler that automatically appends JWT tokens
/// to trusted origin requests and transparently refreshes expired tokens in a thread-safe manner.
/// </summary>
public class AuthenticationHeaderHandler : DelegatingHandler public class AuthenticationHeaderHandler : DelegatingHandler
{ {
private readonly INativeStorageService _storageService; private readonly INativeStorageService _storageService;
private readonly IServiceProvider _serviceProvider;
private const string TokenKey = "nexus_auth_token"; private const string TokenKey = "nexus_auth_token";
private static readonly SemaphoreSlim _refreshSemaphore = new(1, 1);
public AuthenticationHeaderHandler(INativeStorageService storageService) public AuthenticationHeaderHandler(INativeStorageService storageService, IServiceProvider serviceProvider)
{ {
_storageService = storageService; _storageService = storageService;
_serviceProvider = serviceProvider;
} }
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{ {
// Ensure cookies are sent (needed for InteractiveAuto SSR synchronization) // Force browser to forward credentials (cookies) for SSR hydration sync
request.SetBrowserRequestCredentials(BrowserRequestCredentials.Include); request.SetBrowserRequestCredentials(BrowserRequestCredentials.Include);
var tokenResult = await _storageService.GetSecureString(TokenKey); var path = request.RequestUri?.AbsolutePath ?? "";
bool isAuthEndpoint = path.Contains("identity/login") ||
path.Contains("identity/register") ||
path.Contains("identity/refresh");
if (tokenResult.IsSuccess && !string.IsNullOrEmpty(tokenResult.Value)) // SECURITY FIX (CWE-200): Ensure we only append JWT tokens to local or trusted base origin requests
var navigationManager = _serviceProvider.GetRequiredService<NavigationManager>();
bool isTrustedHost = request.RequestUri != null &&
(!request.RequestUri.IsAbsoluteUri ||
request.RequestUri.ToString().StartsWith(navigationManager.BaseUri, StringComparison.OrdinalIgnoreCase));
string? originalToken = null;
if (!isAuthEndpoint && isTrustedHost)
{ {
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenResult.Value); var tokenResult = await _storageService.GetSecureString(TokenKey);
if (tokenResult.IsSuccess && !string.IsNullOrEmpty(tokenResult.Value))
{
originalToken = tokenResult.Value;
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", originalToken);
}
} }
return await base.SendAsync(request, cancellationToken); var response = await base.SendAsync(request, cancellationToken);
// Transparent JWT Auto-Refresh on 401 Unauthorized
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized && !isAuthEndpoint)
{
await _refreshSemaphore.WaitAsync(cancellationToken);
try
{
// Re-read token to verify if another concurrent request already refreshed it
var tokenResult = await _storageService.GetSecureString(TokenKey);
var currentToken = tokenResult.IsSuccess ? tokenResult.Value : null;
bool refreshed = false;
if (!string.IsNullOrEmpty(currentToken) && currentToken != originalToken)
{
refreshed = true;
}
else
{
// SECURITY FIX (CWE-400): Resolve scoped services within an explicit using scope to prevent memory leaks
using var scope = _serviceProvider.CreateScope();
var identityService = scope.ServiceProvider.GetRequiredService<IIdentityService>();
var refreshResult = await identityService.RefreshTokenAsync();
if (refreshResult.IsSuccess)
{
var newTokenResult = await _storageService.GetSecureString(TokenKey);
currentToken = newTokenResult.IsSuccess ? newTokenResult.Value : null;
refreshed = !string.IsNullOrEmpty(currentToken);
}
else
{
await identityService.LogoutAsync();
}
}
if (refreshed && !string.IsNullOrEmpty(currentToken))
{
var newRequest = await CloneHttpRequestMessageAsync(request);
newRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", currentToken);
return await base.SendAsync(newRequest, cancellationToken);
}
}
catch (Exception ex)
{
// Write standard security audit safe debug log
Console.WriteLine($"[AuthHeaderHandler] Automated token renewal failed: {ex.Message}");
}
finally
{
_refreshSemaphore.Release();
}
}
return response;
}
private async Task<HttpRequestMessage> CloneHttpRequestMessageAsync(HttpRequestMessage req)
{
var clone = new HttpRequestMessage(req.Method, req.RequestUri)
{
Version = req.Version
};
if (req.Content != null)
{
var ms = new System.IO.MemoryStream();
await req.Content.CopyToAsync(ms);
ms.Position = 0;
clone.Content = new StreamContent(ms);
foreach (var h in req.Content.Headers)
{
clone.Content.Headers.TryAddWithoutValidation(h.Key, h.Value);
}
}
foreach (var h in req.Headers)
{
clone.Headers.TryAddWithoutValidation(h.Key, h.Value);
}
clone.SetBrowserRequestCredentials(BrowserRequestCredentials.Include);
return clone;
} }
} }