186 lines
7.7 KiB
Python
186 lines
7.7 KiB
Python
from unittest.mock import patch
|
|
|
|
from django.contrib.auth import get_user_model
|
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
|
from django.db.models.signals import post_save
|
|
from django.test import TestCase
|
|
from rest_framework.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN
|
|
from rest_framework.test import APIClient
|
|
|
|
from apps.accounts.models import Organization, Role
|
|
from apps.knowledge.models import RoleRagDocument, TrainingFile, trigger_ingestion
|
|
|
|
# Resolve the project's active user model once, at import time (Django's
# get_user_model() returns the model named by the AUTH_USER_MODEL setting).
User = get_user_model()
|
|
|
|
class KnowledgeApiTests(TestCase):
    """Endpoint tests for ``/api/training-file/`` and ``/api/role-rag-document/``.

    ``setUp`` builds a fresh fixture for every test: a manager ``owner``, a
    regular ``member``, an organization containing both, a ``Researcher``
    role, one ``TrainingFile`` attached to that role, and one
    ``RoleRagDocument`` chunk pointing at that file.
    """

    # Collection endpoints exercised below; detail URLs append "<uuid>/".
    _TRAINING_FILE_URL = '/api/training-file/'
    _RAG_DOCUMENT_URL = '/api/role-rag-document/'

    @classmethod
    def setUpClass(cls):
        """Run the base class setup, then detach ``trigger_ingestion``.

        The ``post_save`` receiver for ``TrainingFile`` stays disconnected for
        the lifetime of this class (it is reattached in ``tearDownClass``), so
        the ``TrainingFile`` rows saved by these tests do not invoke it.
        """
        super().setUpClass()
        post_save.disconnect(trigger_ingestion, sender=TrainingFile)

    @classmethod
    def tearDownClass(cls):
        """Reattach ``trigger_ingestion``, then run the base class teardown."""
        post_save.connect(trigger_ingestion, sender=TrainingFile)
        super().tearDownClass()

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _make_user(email_address, first_name, last_name, date_of_birth,
                   **extra_fields):
        """Create a user with the shared test password ``'pass1234'``.

        ``extra_fields`` (e.g. ``is_manager=True``) are forwarded unchanged
        to ``create_user``.
        """
        return User.objects.create_user(
            email_address=email_address,
            password='pass1234',
            first_name=first_name,
            last_name=last_name,
            date_of_birth=date_of_birth,
            **extra_fields,
        )

    @staticmethod
    def _text_upload(name, body):
        """Wrap ``body`` bytes as an in-memory ``text/plain`` upload named ``name``."""
        return SimpleUploadedFile(name, body, content_type='text/plain')

    @classmethod
    def _training_file_detail_url(cls, uuid):
        """Return the detail URL of the training file identified by ``uuid``."""
        return f'{cls._TRAINING_FILE_URL}{uuid}/'

    @classmethod
    def _rag_document_detail_url(cls, uuid):
        """Return the detail URL of the RAG document identified by ``uuid``."""
        return f'{cls._RAG_DOCUMENT_URL}{uuid}/'

    # -- fixture -----------------------------------------------------------

    def setUp(self):
        """Create the users, organization, role, training file and RAG chunk."""
        self.client: APIClient = APIClient()

        self.owner = self._make_user(
            email_address='owner-k@example.com',
            first_name='Owner',
            last_name='K',
            date_of_birth='1990-01-01',
            is_manager=True,
        )
        self.member = self._make_user(
            email_address='member-k@example.com',
            first_name='Member',
            last_name='K',
            date_of_birth='1992-02-02',
        )

        self.org = Organization.objects.create(name='Knowledge API Org', owner=self.owner)
        self.org.members.add(self.owner, self.member)
        self.role = Role.objects.create(name='Researcher', organization=self.org)

        self.training_file = TrainingFile.objects.create(
            organization=self.org,
            role=self.role,
            uploaded_by=self.owner,
            file=self._text_upload('doc.txt', b'content'),
            file_name='doc.txt',
            file_size=7,  # len(b'content')
            file_type='text/plain',
        )
        self.rag_doc = RoleRagDocument.objects.create(
            organization=self.org,
            role=self.role,
            training_file=self.training_file,
            content='chunk body',
            content_hash='b' * 64,
            metadata={'k': 'v'},
            chunk_index=0,
        )

    # -- training files: basic routes --------------------------------------

    def test_training_file_list_path(self):
        """An organization member can list training files (200)."""
        self.client.force_authenticate(self.member)
        resp = self.client.get(self._TRAINING_FILE_URL)
        self.assertEqual(resp.status_code, HTTP_200_OK)

    def test_training_file_create_path(self):
        """A member's create request (no ``file_name`` field) is answered with 400.

        NOTE(review): test_training_file_create_by_owner_succeeds sends
        ``file_name`` and gets 201; confirm whether this 400 comes from the
        missing ``file_name`` or from something specific to a regular member.
        """
        self.client.force_authenticate(self.member)
        payload = {
            'role_uuid': str(self.role.uuid),
            'description': 'new file',
            'file': self._text_upload('new.txt', b'new body'),
        }
        resp = self.client.post(self._TRAINING_FILE_URL, payload)
        self.assertEqual(resp.status_code, HTTP_400_BAD_REQUEST)

    def test_training_file_retrieve_path(self):
        """An organization member can fetch one training file by UUID (200)."""
        self.client.force_authenticate(self.member)
        resp = self.client.get(self._training_file_detail_url(self.training_file.uuid))
        self.assertEqual(resp.status_code, HTTP_200_OK)

    def test_training_file_update_path(self):
        """An owner PUT carrying only ``description`` and a new ``file`` yields 400.

        NOTE(review): which field the serializer rejects is not visible
        here — confirm the cause of the 400.
        """
        self.client.force_authenticate(self.owner)
        resp = self.client.put(
            self._training_file_detail_url(self.training_file.uuid),
            {
                'description': 'updated desc',
                'file': self._text_upload('replace.txt', b'updated'),
            },
        )
        self.assertEqual(resp.status_code, HTTP_400_BAD_REQUEST)

    def test_training_file_partial_update_path(self):
        """An owner multipart PATCH of ``description`` returns 200 or 400.

        NOTE(review): accepting both statuses only rules out other outcomes
        (e.g. 403/404/500); TODO pin the single expected status.
        """
        self.client.force_authenticate(self.owner)
        resp = self.client.patch(
            self._training_file_detail_url(self.training_file.uuid),
            {'description': 'patched desc'},
            format='multipart',
        )
        self.assertIn(resp.status_code, (HTTP_200_OK, HTTP_400_BAD_REQUEST))

    @patch('apps.knowledge.viewsets.update_agent_prompts_from_file_task')
    def test_training_file_destroy_path(self, _mock_task):
        """The organization owner can delete a training file (204).

        ``update_agent_prompts_from_file_task`` is mocked in
        ``apps.knowledge.viewsets`` — presumably so no real task is
        dispatched; the mock itself is not inspected.
        """
        self.client.force_authenticate(self.owner)
        resp = self.client.delete(self._training_file_detail_url(self.training_file.uuid))
        self.assertEqual(resp.status_code, HTTP_204_NO_CONTENT)

    # -- RAG documents -----------------------------------------------------

    def test_role_rag_document_list_path(self):
        """An organization member can list RAG documents (200)."""
        self.client.force_authenticate(self.member)
        resp = self.client.get(self._RAG_DOCUMENT_URL)
        self.assertEqual(resp.status_code, HTTP_200_OK)

    def test_role_rag_document_retrieve_path(self):
        """An organization member can fetch one RAG document by UUID (200)."""
        self.client.force_authenticate(self.member)
        resp = self.client.get(self._rag_document_detail_url(self.rag_doc.uuid))
        self.assertEqual(resp.status_code, HTTP_200_OK)

    # -- visibility, validation and permissions ----------------------------

    def test_training_file_list_for_non_member_returns_empty(self):
        """A user outside the organization gets 200 with an empty payload."""
        outsider = self._make_user(
            email_address='outsider-k@example.com',
            first_name='Out',
            last_name='Sider',
            date_of_birth='1994-04-04',
        )
        self.client.force_authenticate(outsider)
        resp = self.client.get(self._TRAINING_FILE_URL)
        self.assertEqual(resp.status_code, HTTP_200_OK)
        self.assertEqual(len(resp.json()), 0)

    def test_training_file_create_requires_role_uuid(self):
        """Creating with neither ``role_uuid`` nor ``organization_uuid`` fails (400).

        Despite the test name, the validation error is expected under the
        ``organization_uuid`` key.
        """
        self.client.force_authenticate(self.owner)
        payload = {
            'file': self._text_upload('new.txt', b'new body'),
            'file_name': 'new.txt',
        }
        resp = self.client.post(self._TRAINING_FILE_URL, payload)
        self.assertEqual(resp.status_code, HTTP_400_BAD_REQUEST)
        self.assertIn('organization_uuid', resp.json())

    def test_training_file_create_by_owner_succeeds(self):
        """The owner can create a role-scoped training file (201)."""
        self.client.force_authenticate(self.owner)
        payload = {
            'role_uuid': str(self.role.uuid),
            'file': self._text_upload('owner-ok.txt', b'owner file'),
            'file_name': 'owner-ok.txt',
        }
        resp = self.client.post(self._TRAINING_FILE_URL, payload)
        self.assertEqual(resp.status_code, HTTP_201_CREATED)

    def test_training_file_create_org_wide_by_owner_succeeds(self):
        """Posting ``organization_uuid`` without a role creates a role-less file (201)."""
        self.client.force_authenticate(self.owner)
        payload = {
            'organization_uuid': str(self.org.uuid),
            'file': self._text_upload('org-wide.txt', b'org policy'),
            'file_name': 'org-wide.txt',
        }
        resp = self.client.post(self._TRAINING_FILE_URL, payload)
        self.assertEqual(resp.status_code, HTTP_201_CREATED)
        self.assertIsNone(resp.json().get('role'))

    def test_training_file_destroy_forbidden_for_regular_member(self):
        """A non-manager member may not delete a training file (403)."""
        self.client.force_authenticate(self.member)
        resp = self.client.delete(self._training_file_detail_url(self.training_file.uuid))
        self.assertEqual(resp.status_code, HTTP_403_FORBIDDEN)

    @patch('apps.knowledge.viewsets.update_agent_prompts_from_file_task')
    def test_training_file_destroy_allowed_for_org_manager_member(self, _mock_task):
        """A member flagged ``is_manager`` (not the owner) can delete (204).

        The agent-prompt update task is mocked as in
        ``test_training_file_destroy_path``; the mock is not inspected.
        """
        manager_member = self._make_user(
            email_address='manager-member-k@example.com',
            first_name='Manager',
            last_name='Member',
            date_of_birth='1995-05-05',
            is_manager=True,
        )
        self.org.members.add(manager_member)

        self.client.force_authenticate(manager_member)
        resp = self.client.delete(self._training_file_detail_url(self.training_file.uuid))
        self.assertEqual(resp.status_code, HTTP_204_NO_CONTENT)
|