Added role agnostic training files

This commit is contained in:
Viswamedha Nalabotu 2026-03-15 22:19:12 +00:00
parent 6ccb7822c9
commit a9ba16c76d
11 changed files with 171 additions and 67 deletions

View file

@ -5,22 +5,22 @@ from apps.knowledge.models import RoleRagDocument, TrainingFile
@admin.register(TrainingFile)
class TrainingFileAdmin(admin.ModelAdmin):
list_display = ('file_name', 'role', 'status', 'is_processed', 'uploaded_by', 'created_at')
list_filter = ('status', 'is_processed', 'role__organization', 'created_at')
search_fields = ('file_name', 'role__name', 'uploaded_by__email_address')
raw_id_fields = ('role', 'uploaded_by')
list_display = ('file_name', 'organization', 'role', 'status', 'is_processed', 'uploaded_by', 'created_at')
list_filter = ('status', 'is_processed', 'organization', 'created_at')
search_fields = ('file_name', 'organization__name', 'role__name', 'uploaded_by__email_address')
raw_id_fields = ('organization', 'role', 'uploaded_by')
readonly_fields = ('uuid', 'file_size', 'file_type', 'created_at', 'updated_at')
ordering = ('-created_at',)
@admin.register(RoleRagDocument)
class RoleRagDocumentAdmin(admin.ModelAdmin):
list_display = ('role', 'chunk_index', 'training_file', 'is_active', 'created_at')
list_filter = ('is_active', 'role__organization', 'created_at')
search_fields = ('content', 'role__name', 'training_file__file_name')
raw_id_fields = ('role', 'training_file')
list_display = ('organization', 'role', 'chunk_index', 'training_file', 'is_active', 'created_at')
list_filter = ('is_active', 'organization', 'created_at')
search_fields = ('content', 'organization__name', 'role__name', 'training_file__file_name')
raw_id_fields = ('organization', 'role', 'training_file')
readonly_fields = ('uuid', 'content_hash', 'display_embedding', 'created_at', 'updated_at')
ordering = ('role', 'chunk_index')
ordering = ('organization', 'role', 'chunk_index')
def get_fields(self, request, obj=None):
fields = super().get_fields(request, obj)

View file

@ -30,7 +30,8 @@ class Migration(migrations.Migration):
('description', models.TextField(blank=True, default='')),
('status', models.CharField(choices=[('ingesting', 'Ingesting'), ('chunked', 'Chunked'), ('embedded', 'Embedded'), ('failed', 'Failed')], default='ingesting', max_length=20)),
('is_processed', models.BooleanField(default=False)),
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.role')),
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.organization')),
('role', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.role')),
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
],
options={
@ -52,7 +53,8 @@ class Migration(migrations.Migration):
('metadata', models.JSONField(blank=True, default=dict)),
('chunk_index', models.IntegerField(default=0)),
('is_active', models.BooleanField(default=True)),
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='accounts.role')),
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='accounts.organization')),
('role', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='rag_documents', to='accounts.role')),
('training_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='chunks', to='knowledge.trainingfile')),
],
options={

View file

@ -2,14 +2,14 @@ import os
from django.conf import settings
from django.db import transaction
from django.db.models import CASCADE, BooleanField, CharField, FileField, ForeignKey, IntegerField, JSONField, Model, TextField
from django.db.models import CASCADE, SET_NULL, BooleanField, CharField, FileField, ForeignKey, IntegerField, JSONField, Model, TextField
from django.db.models.signals import post_delete, post_save
from django.dispatch import receiver
from django.utils.translation import gettext_lazy as _
from pgvector.django import VectorField
from apps.accounts.mixins import IdentifierMixin, TimeStampMixin
from apps.accounts.models import Role, User
from apps.accounts.models import Organization, Role, User
class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
STATUS_CHOICES = [
@ -19,7 +19,8 @@ class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
('failed', 'Failed'),
]
role = ForeignKey(Role, on_delete=CASCADE, related_name="training_files")
organization = ForeignKey(Organization, on_delete=CASCADE, related_name="training_files")
role = ForeignKey(Role, on_delete=CASCADE, related_name="training_files", null=True, blank=True)
uploaded_by = ForeignKey(User, on_delete=CASCADE, related_name="uploaded_training_files")
file = FileField(upload_to='training_files/%Y/%m/%d/')
@ -37,11 +38,14 @@ class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
ordering = ['-created_at']
def __str__(self) -> str:
return f"{self.file_name} ({self.role.name})"
if self.role_id:
return f"{self.file_name} ({self.role.name})"
return f"{self.file_name} ({self.organization.name} - Organization-wide)"
class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model):
role = ForeignKey(Role, on_delete=CASCADE, related_name='rag_documents')
organization = ForeignKey(Organization, on_delete=CASCADE, related_name='rag_documents')
role = ForeignKey(Role, on_delete=SET_NULL, related_name='rag_documents', null=True, blank=True)
training_file = ForeignKey(TrainingFile, on_delete=CASCADE, related_name='chunks', null=True, blank=True)
content = TextField()
@ -58,7 +62,9 @@ class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model):
verbose_name_plural = _("Role RAG Documents")
def __str__(self) -> str:
return f"{self.role.name} - Chunk {self.chunk_index}"
if self.role_id:
return f"{self.role.name} - Chunk {self.chunk_index}"
return f"{self.organization.name} (Organization-wide) - Chunk {self.chunk_index}"
@receiver(post_delete, sender=TrainingFile)
def delete_physical_file(sender, instance, **kwargs):

View file

@ -1,32 +1,37 @@
from rest_framework.serializers import ModelSerializer, SerializerMethodField
from apps.accounts.serializers import RoleSerializer, UserSerializer
from apps.accounts.serializers import OrganizationSerializer, RoleSerializer, UserSerializer
from apps.knowledge.models import RoleRagDocument, TrainingFile
class TrainingFileSerializer(ModelSerializer):
uploaded_by = UserSerializer(read_only=True)
organization = OrganizationSerializer(read_only=True)
role = RoleSerializer(read_only=True)
file_url = SerializerMethodField()
scope = SerializerMethodField()
class Meta:
model = TrainingFile
fields = [
'id', 'uuid', 'role', 'uploaded_by', 'file', 'file_url',
'id', 'uuid', 'organization', 'role', 'scope', 'uploaded_by', 'file', 'file_url',
'file_name', 'file_size', 'file_type', 'description',
'status', 'is_processed', 'created_at', 'updated_at'
]
read_only_fields = [
'id', 'uuid', 'uploaded_by', 'file_size', 'file_type',
'status', 'is_processed', 'created_at', 'updated_at',
'role'
'organization', 'role', 'scope'
]
def get_file_url(self, obj: TrainingFile) -> str:
def get_file_url(self, obj: TrainingFile):
request = self.context.get('request')
if obj.file and request:
return request.build_absolute_uri(obj.file.url)
return obj.file.url if obj.file else None
def get_scope(self, obj: TrainingFile) -> str:
return 'role' if obj.role_id else 'organization'
class RoleRagDocumentSerializer(ModelSerializer):
training_file_name = SerializerMethodField()

View file

@ -77,13 +77,18 @@ def ingest_training_file_task(self, file_uuid):
for chunk_text, embedding in zip(chunks, embeddings):
all_documents.append(RoleRagDocument(
organization=file_obj.organization,
role=file_obj.role,
training_file=file_obj,
content=chunk_text,
content_hash=hashlib.sha256(chunk_text.encode('utf-8')).hexdigest(),
embedding=embedding,
chunk_index=chunk_counter,
metadata={"source": file_obj.file_name}
metadata={
"source": file_obj.file_name,
"file_name": file_obj.file_name,
"scope": "role" if file_obj.role_id else "organization",
},
))
chunk_counter += 1

View file

@ -44,6 +44,7 @@ class KnowledgeApiTests(TestCase):
self.role = Role.objects.create(name='Researcher', organization=self.org)
self.training_file = TrainingFile.objects.create(
organization=self.org,
role=self.role,
uploaded_by=self.owner,
file=SimpleUploadedFile('doc.txt', b'content', content_type='text/plain'),
@ -52,6 +53,7 @@ class KnowledgeApiTests(TestCase):
file_type='text/plain',
)
self.rag_doc = RoleRagDocument.objects.create(
organization=self.org,
role=self.role,
training_file=self.training_file,
content='chunk body',
@ -136,7 +138,7 @@ class KnowledgeApiTests(TestCase):
'file_name': 'new.txt',
})
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
self.assertIn('role_uuid', response.json())
self.assertIn('organization_uuid', response.json())
def test_training_file_create_by_owner_succeeds(self):
self.client.force_authenticate(self.owner)
@ -148,6 +150,17 @@ class KnowledgeApiTests(TestCase):
})
self.assertEqual(response.status_code, HTTP_201_CREATED)
def test_training_file_create_org_wide_by_owner_succeeds(self):
self.client.force_authenticate(self.owner)
uploaded = SimpleUploadedFile('org-wide.txt', b'org policy', content_type='text/plain')
response = self.client.post('/api/training-file/', {
'organization_uuid': str(self.org.uuid),
'file': uploaded,
'file_name': 'org-wide.txt',
})
self.assertEqual(response.status_code, HTTP_201_CREATED)
self.assertIsNone(response.json().get('role'))
def test_training_file_destroy_forbidden_for_regular_member(self):
self.client.force_authenticate(self.member)
response = self.client.delete(f'/api/training-file/{self.training_file.uuid}/')

View file

@ -34,6 +34,7 @@ class KnowledgeModelTests(TestCase):
def test_training_file_fields_and_defaults(self):
uploaded = SimpleUploadedFile('training.txt', b'hello world', content_type='text/plain')
training_file = TrainingFile.objects.create(
organization=self.org,
role=self.role,
uploaded_by=self.user,
file=uploaded,
@ -44,6 +45,7 @@ class KnowledgeModelTests(TestCase):
)
self.assertEqual(training_file.role, self.role)
self.assertEqual(training_file.organization, self.org)
self.assertEqual(training_file.uploaded_by, self.user)
self.assertEqual(training_file.file_name, 'training.txt')
self.assertEqual(training_file.file_size, 11)
@ -62,6 +64,7 @@ class KnowledgeModelTests(TestCase):
def test_role_rag_document_fields_and_defaults(self):
uploaded = SimpleUploadedFile('base.txt', b'base', content_type='text/plain')
training_file = TrainingFile.objects.create(
organization=self.org,
role=self.role,
uploaded_by=self.user,
file=uploaded,
@ -70,6 +73,7 @@ class KnowledgeModelTests(TestCase):
file_type='text/plain',
)
document = RoleRagDocument.objects.create(
organization=self.org,
role=self.role,
training_file=training_file,
content='Chunk content',
@ -80,6 +84,7 @@ class KnowledgeModelTests(TestCase):
)
self.assertEqual(document.role, self.role)
self.assertEqual(document.organization, self.org)
self.assertEqual(document.training_file, training_file)
self.assertEqual(document.content, 'Chunk content')
self.assertEqual(document.content_hash, 'a' * 64)

View file

@ -4,7 +4,7 @@ from rest_framework.parsers import FormParser, MultiPartParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
from apps.accounts.models import Role
from apps.accounts.models import Organization, Role
from apps.accounts.permissions import can_manage_organization
from apps.knowledge.models import RoleRagDocument, TrainingFile
from apps.knowledge.serializers import RoleRagDocumentSerializer, TrainingFileSerializer
@ -19,35 +19,51 @@ class TrainingFileViewSet(ModelViewSet):
def get_queryset(self):
user = self.request.user
queryset = TrainingFile.objects.filter(
Q(role__organization__owner=user) |
Q(role__organization__members=user)
Q(organization__owner=user) |
Q(organization__members=user)
).distinct()
organization_uuid = self.request.query_params.get('organization_uuid')
if organization_uuid in (None, ''):
organization_uuid = self.request.data.get('organization_uuid')
if organization_uuid:
queryset = queryset.filter(role__organization__uuid=organization_uuid)
queryset = queryset.filter(organization__uuid=organization_uuid)
role_uuid = self.request.query_params.get('role_uuid')
if role_uuid in (None, ''):
role_uuid = self.request.data.get('role_uuid')
if role_uuid:
queryset = queryset.filter(role__uuid=role_uuid)
queryset = queryset.filter(Q(role__uuid=role_uuid) | Q(role__isnull=True))
return queryset
def perform_create(self, serializer):
role_uuid = self.request.data.get('role_uuid')
if not role_uuid:
raise ValidationError({'role_uuid': 'role_uuid is required.'})
organization_uuid = self.request.data.get('organization_uuid')
try:
role = Role.objects.get(uuid=role_uuid)
except Role.DoesNotExist:
raise NotFound('Role not found')
role = None
organization = None
if not can_manage_organization(self.request.user, role.organization):
if role_uuid:
try:
role = Role.objects.select_related('organization').get(uuid=role_uuid)
except Role.DoesNotExist:
raise NotFound('Role not found')
organization = role.organization
if organization_uuid and str(organization.uuid) != str(organization_uuid):
raise ValidationError({'organization_uuid': 'organization_uuid does not match role organization.'})
else:
if not organization_uuid:
raise ValidationError({'organization_uuid': 'organization_uuid is required when role_uuid is not provided.'})
try:
organization = Organization.objects.get(uuid=organization_uuid)
except Organization.DoesNotExist:
raise NotFound('Organization not found')
if not can_manage_organization(self.request.user, organization):
raise PermissionDenied('Permission denied')
uploaded_file = self.request.FILES.get('file')
@ -56,6 +72,7 @@ class TrainingFileViewSet(ModelViewSet):
serializer.save(
uploaded_by=self.request.user,
organization=organization,
role=role,
file_name=uploaded_file.name,
file_size=uploaded_file.size,
@ -66,8 +83,8 @@ class TrainingFileViewSet(ModelViewSet):
instance = self.get_object()
is_uploader = instance.uploaded_by == request.user
is_org_owner = instance.role.organization.owner == request.user
is_org_manager = bool(request.user.is_manager) and instance.role.organization.members.filter(id=request.user.id).exists()
is_org_owner = instance.organization.owner == request.user
is_org_manager = bool(request.user.is_manager) and instance.organization.members.filter(id=request.user.id).exists()
if not (is_uploader or is_org_owner or is_org_manager):
raise PermissionDenied('Permission denied')
@ -83,15 +100,15 @@ class RoleRagDocumentViewSet(ReadOnlyModelViewSet):
def get_queryset(self):
user = self.request.user
queryset = RoleRagDocument.objects.filter(
Q(role__organization__owner=user) |
Q(role__organization__members=user)
Q(organization__owner=user) |
Q(organization__members=user)
).distinct()
organization_uuid = self.request.query_params.get('organization_uuid')
if organization_uuid in (None, ''):
organization_uuid = self.request.data.get('organization_uuid')
if organization_uuid:
queryset = queryset.filter(role__organization__uuid=organization_uuid)
queryset = queryset.filter(organization__uuid=organization_uuid)
role_uuid = self.request.query_params.get('role_uuid')
if role_uuid in (None, ''):

View file

@ -4,8 +4,10 @@ import random
from channels.db import database_sync_to_async
from django.conf import settings
from django.db.models import Q
from pgvector.django import CosineDistance
from apps.accounts.models import Role
from apps.knowledge.models import RoleRagDocument
from apps.onboarding.models import OnboardingSession
@ -105,9 +107,17 @@ class MCPRouter:
@database_sync_to_async
def _search_knowledge_documents(self, role_uuid, query_vector):
role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
if role is None:
logger.warning('MCP search_knowledge_documents role not found: role_uuid=%s', role_uuid)
return []
docs = RoleRagDocument.objects.filter(
role__uuid=role_uuid,
organization=role.organization,
embedding__isnull=False,
is_active=True,
).filter(
Q(role__uuid=role_uuid) | Q(role__isnull=True),
).annotate(
distance=CosineDistance('embedding', query_vector)
).order_by('distance')[:5]
@ -115,7 +125,7 @@ class MCPRouter:
results = [
{
'content': d.content,
'source': d.metadata.get('file_name', 'Unknown Source'),
'source': d.metadata.get('file_name') or d.metadata.get('source', 'Unknown Source'),
'relevance': round(1 - d.distance, 4),
}
for d in docs

View file

@ -35,7 +35,9 @@ export interface InviteToken {
}
export interface TrainingFile {
uuid: string
role: Role
organization: Organization
role: Role | null
scope?: 'role' | 'organization'
uploaded_by: User
file: string
file_name: string

View file

@ -67,6 +67,12 @@ const inviteModalVisible = ref(false)
const newInviteUrl = ref('')
const editingDescription = ref(false)
const newDescription = ref('')
const ORGANIZATION_WIDE_SCOPE = '__organization_wide__'
const uploadRoleOptions = computed(() => [
{ label: 'Organization-wide (all roles)', value: ORGANIZATION_WIDE_SCOPE },
...Roles.value.map((role) => ({ label: role.name, value: role.uuid })),
])
const filteredMembers = computed(() => {
const query = memberSearch.value.trim().toLowerCase()
@ -151,6 +157,8 @@ const fetchTrainingFiles = async () => {
}
}
const getScopeLabel = (file: TrainingFile) => (file.role?.name ? file.role.name : 'Organization-wide')
const resetRoleWizard = () => {
roleWizardStep.value = 0
@ -200,7 +208,7 @@ const validateUploadFile = (file: File): boolean => {
}
const uploadTrainingFile = async (
roleUuid: string,
roleUuid: string | null,
file: File,
description: string,
): Promise<TrainingFile | null> => {
@ -208,7 +216,10 @@ const uploadTrainingFile = async (
formData.append('file', file)
formData.append('file_name', file.name)
formData.append('description', description)
formData.append('role_uuid', roleUuid)
formData.append('organization_uuid', organizationUuid)
if (roleUuid) {
formData.append('role_uuid', roleUuid)
}
try {
const response = await apiClient.post<TrainingFile>(API.knowledge.trainingFiles.list(), formData, {
@ -234,6 +245,10 @@ const uploadTrainingFile = async (
const getTrainingFilesByRole = (roleUuid: string): TrainingFile[] =>
trainingFiles.value.filter((file) => file.role?.uuid === roleUuid)
const organizationWideTrainingFiles = computed(() =>
trainingFiles.value.filter((file) => !file.role?.uuid),
)
const deleteTrainingFile = async (uuid: string, fileName: string) => {
Modal.confirm({
title: 'Delete File',
@ -289,9 +304,9 @@ const trainingFileColumns = [
customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
},
{
title: 'Role',
title: 'Scope',
key: 'role',
customRender: ({ record }: { record: TrainingFile }) => record.role?.name || '-',
customRender: ({ record }: { record: TrainingFile }) => getScopeLabel(record),
},
{
title: 'Status',
@ -426,7 +441,7 @@ const uploadFileFromWizard = async () => {
}
const openUploadModal = (role?: Role) => {
uploadRoleUuid.value = role?.uuid || ''
uploadRoleUuid.value = role?.uuid || ORGANIZATION_WIDE_SCOPE
uploadSelectedFile.value = null
uploadFileDescription.value = ''
uploadModalVisible.value = true
@ -441,20 +456,18 @@ const handleUploadModalFileSelected = (file: File) => {
}
const handleUploadModalOk = async () => {
if (!uploadRoleUuid.value) {
message.error('Please select a role for this training file')
return
}
if (!uploadSelectedFile.value) {
message.error('Please select a file to upload')
return
}
const selectedRoleUuid =
uploadRoleUuid.value === ORGANIZATION_WIDE_SCOPE ? null : uploadRoleUuid.value
uploadingFile.value = true
try {
const uploaded = await uploadTrainingFile(
uploadRoleUuid.value,
selectedRoleUuid,
uploadSelectedFile.value,
uploadFileDescription.value,
)
@ -848,7 +861,29 @@ onMounted(async () => {
</List.Item>
</template>
</List>
<Typography.Paragraph v-else type="secondary">
<div v-if="organizationWideTrainingFiles.length > 0" class="role-files" style="margin-top: 1rem">
<Typography.Text strong>
Organization-wide training files (applies to all roles)
</Typography.Text>
<List
:data-source="organizationWideTrainingFiles"
size="small"
:bordered="false"
>
<template #renderItem="{ item: file }">
<List.Item>
<Space style="display: flex; justify-content: space-between; width: 100%">
<Typography.Text>{{ file.file_name }}</Typography.Text>
<Tag color="geekblue">Organization-wide</Tag>
</Space>
</List.Item>
</template>
</List>
</div>
<Typography.Paragraph
v-if="filteredRoles.length === 0 && organizationWideTrainingFiles.length === 0"
type="secondary"
>
{{ roleEmptyMessage }}
</Typography.Paragraph>
</div>
@ -913,7 +948,8 @@ onMounted(async () => {
<Typography.Paragraph type="secondary" style="margin-bottom: 0">
Upload optional training files for
<strong>{{ createdRoleForWizard?.name }}</strong>
. You can also do this later.
. You can also do this later. Use the main Upload Training File modal for
organization-wide files.
</Typography.Paragraph>
<Input.TextArea
@ -970,7 +1006,7 @@ onMounted(async () => {
title="Upload Training File"
ok-text="Upload"
cancel-text="Cancel"
:ok-button-props="{ loading: uploadingFile, disabled: !uploadRoleUuid || !uploadSelectedFile }"
:ok-button-props="{ loading: uploadingFile, disabled: !uploadSelectedFile }"
@ok="handleUploadModalOk"
@cancel="uploadModalVisible = false"
>
@ -982,13 +1018,16 @@ onMounted(async () => {
</Typography.Text>
<div>
<Typography.Text strong>Role</Typography.Text>
<Typography.Text strong>Scope</Typography.Text>
<Select
v-model:value="uploadRoleUuid"
placeholder="Select a role"
placeholder="Select training scope"
style="width: 100%"
:options="Roles.map((role) => ({ label: role.name, value: role.uuid }))"
:options="uploadRoleOptions"
/>
<Typography.Paragraph type="secondary" style="margin: 0.5rem 0 0">
Organization-wide files apply to every role in this organization.
</Typography.Paragraph>
</div>
<Input.TextArea