You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and dots ('.'), can be up to 35 characters long. Letters must be lowercase.
216 lines
8.1 KiB
216 lines
8.1 KiB
from drf_yasg.utils import swagger_auto_schema |
|
from drf_yasg import openapi |
|
|
|
from rest_framework_simplejwt.views import TokenObtainPairView |
|
from rest_framework_simplejwt.tokens import RefreshToken |
|
from rest_framework.response import Response |
|
from rest_framework.permissions import AllowAny, IsAuthenticated |
|
from rest_framework import status |
|
from rest_framework.decorators import action, api_view |
|
from adrf import viewsets |
|
from django.views.decorators.http import require_POST |
|
from django.utils import timezone |
|
from django.contrib.auth import get_user_model |
|
from .models import MasterKey, Conversation |
|
from django.http import StreamingHttpResponse |
|
from .models import Message |
|
from .serializers import ConversationSerializer, MessageSerializer |
|
|
|
from .utils import add_swagger_summaries |
|
from django.views import View |
|
import ollama |
|
from asgiref.sync import sync_to_async |
|
from django.utils.decorators import method_decorator |
|
import asyncio, base64, json |
|
from django.views.decorators.csrf import csrf_exempt |
|
|
|
# The project's configured user model, resolved once at import time.
# MasterKeyTokenObtainView uses it to look up or create the user it issues
# tokens for.
User = get_user_model()
|
|
|
@add_swagger_summaries
class ConversationView(viewsets.ModelViewSet):
    """CRUD endpoints for conversations, plus chat actions backed by Ollama.

    Besides the standard model-viewset routes, this view exposes:

    * ``GET  .../{pk}/contents/`` -- every stored message of a conversation.
    * ``POST .../{pk}/prompt/``   -- append a user message and stream the
      model's reply back as text chunks.

    Access requires an authenticated user, and each user only sees the
    conversations they own.
    """

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    # Conversations are owned by a user (see perform_create), so an
    # authenticated request is required.
    # NOTE(review): assumes an authentication class (e.g. simplejwt's
    # JWTAuthentication) is configured in the REST_FRAMEWORK settings.
    permission_classes = [IsAuthenticated]

    # Ollama model used by ``prompt``. Override it in a subclass to switch models.
    ollama_model = "gemma3:12b"
    # Optional system prompt sent ahead of the conversation history. ``None``
    # (the default) sends no system message.
    system_prompt = None

    def get_queryset(self):
        """Return only the conversations owned by the requesting user."""
        # drf-yasg sets ``swagger_fake_view`` on the views it instantiates for
        # schema generation. There is no real user in that case, so return an
        # empty queryset instead of filtering on an anonymous user.
        if getattr(self, 'swagger_fake_view', False):
            return Conversation.objects.none()
        return Conversation.objects.filter(user=self.request.user)

    def perform_create(self, serializer):
        """Associate the new conversation with the requesting user."""
        serializer.save(user=self.request.user)

    @swagger_auto_schema(
        method='get',
        operation_description="Get all the messages of the conversation",
        operation_summary="Get all messages",
        responses={
            200: openapi.Response('List of messages', MessageSerializer(many=True)),
            400: 'Bad Request',
            404: 'Conversation not found',
        },
    )
    @action(detail=True, methods=['get'])
    async def contents(self, request, pk=None):
        """Return every message of the conversation as a list of field dicts.

        ``get_object`` and the queryset evaluation use the synchronous ORM.
        Django refuses to run that directly inside a coroutine, so the work is
        done through ``sync_to_async``.

        Raises:
            Http404: (via ``get_object``) when the conversation is not among
                the requesting user's conversations.
        """
        # NOTE(review): the schema advertises MessageSerializer, but the raw
        # ``.values()`` dicts are returned. This is kept unchanged so the
        # response shape stays the same for existing clients.
        def load_messages():
            conversation = self.get_object()
            return list(conversation.messages.values())

        return Response(data=await sync_to_async(load_messages)())

    @swagger_auto_schema(
        operation_description="Send a message to the AI and stream its reply",
        operation_summary="Make a prompt",
        request_body=openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                'content': openapi.Schema(type=openapi.TYPE_STRING, description='Contents of the message'),
            },
            required=['content'],
        ),
        responses={
            200: openapi.Response(
                description="Text stream response",
                schema=openapi.Schema(type=openapi.TYPE_STRING, format='binary'),
            ),
            400: 'Bad Request',
            404: 'Conversation not found',
        },
        tags=['conversations'],
    )
    @action(detail=True, methods=['post'])
    def prompt(self, request, pk=None):
        """Append a user message to the conversation and stream the AI reply.

        The request body must contain a non-blank string ``content``. The
        stored history (optionally preceded by ``system_prompt``) plus the new
        message is sent to the Ollama model named by ``ollama_model``. Each
        chunk of the reply is streamed back as it arrives.

        The user message is saved before streaming starts. The assistant reply
        is saved when the stream finishes or is interrupted, and contains
        whatever text had been received by then.

        Returns:
            A StreamingHttpResponse carrying the reply text, or a ``400``
            Response when ``content`` is missing, not a string, or blank.

        Raises:
            Http404: (via ``get_object``) when the conversation is not among
                the requesting user's conversations.
        """
        conversation = self.get_object()

        content = request.data.get('content')
        if not isinstance(content, str) or not content.strip():
            return Response(
                {'error': 'content is required'},
                status=status.HTTP_400_BAD_REQUEST,
            )

        history = []
        if self.system_prompt:
            history.append({"role": "system", "content": self.system_prompt})
        # NOTE(review): relies on the Message model's default ordering being
        # chronological. Confirm against the model's Meta.
        history.extend(
            {"role": message.role, "content": message.content}
            for message in conversation.messages.all()
        )
        history.append({"role": "user", "content": content})

        Message(role="user", content=content, conversation=conversation).save()
        ai_message = Message(role="assistant", content="", conversation=conversation)
        model = self.ollama_model

        @sync_to_async
        def save_reply():
            ai_message.save()

        async def chat_event_stream():
            parts = []
            try:
                # The async client keeps the event loop free while waiting on
                # Ollama. The synchronous ``ollama.chat`` iterator blocked the
                # loop for the whole generation.
                stream = await ollama.AsyncClient().chat(
                    model=model, messages=history, stream=True
                )
                async for chunk in stream:
                    piece = chunk['message']['content']
                    parts.append(piece)
                    yield piece
            finally:
                # Persist whatever was received, even if the client disconnected
                # or the model call failed part-way.
                ai_message.content = "".join(parts)
                await save_reply()

        # NOTE(review): chunks are raw text, not SSE-framed ``data:`` lines,
        # despite this content type. It is kept unchanged for existing clients.
        response = StreamingHttpResponse(chat_event_stream(), content_type='text/event-stream')
        response['Cache-Control'] = 'no-cache'
        # Ask nginx-style proxies not to buffer the stream.
        response['X-Accel-Buffering'] = 'no'
        return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MasterKeyTokenObtainView(TokenObtainPairView):
    """Issue a JWT pair for a user on behalf of an application holding a master key.

    The caller proves its identity with an active ``MasterKey`` instead of a
    user password. The named user is created on first use.
    """

    permission_classes = [AllowAny]

    @swagger_auto_schema(
        operation_description="Creates a token for a user (creates the user if it doesn't exist)",
        request_body=openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                'username': openapi.Schema(
                    type=openapi.TYPE_STRING,
                    description='Identifier of the user (defaults to a service user of the master key)',
                ),
                'master_key': openapi.Schema(type=openapi.TYPE_STRING, description='Key of the authorized application'),
            },
            required=['master_key'],
        ),
        responses={
            200: openapi.Response('API tokens of the user', openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'access': openapi.Schema(type=openapi.TYPE_STRING, description='API access token of the user'),
                    'refresh': openapi.Schema(type=openapi.TYPE_STRING, description='API refresh token of the user'),
                },
                required=['access', 'refresh'],
            )),
            400: 'Bad Request',
            401: 'Bad master key',
            403: 'User account is disabled',
        },
    )
    def post(self, request, *args, **kwargs):
        """Exchange a master key (and an optional username) for a JWT pair.

        Args:
            request: DRF request whose body carries ``master_key`` and,
                optionally, ``username``.

        Returns:
            Response: ``{'refresh': ..., 'access': ...}`` on success. Error
            responses are ``400`` when the master key is missing, ``401``
            when it is unknown or inactive, and ``403`` when the target user
            account is inactive.
        """
        master_key_value = request.data.get('master_key')
        if not master_key_value:
            return Response(
                {'error': 'Master key is required'},
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Verify the master key locally.
        try:
            master_key = MasterKey.objects.get(key_value=master_key_value, is_active=True)
        except MasterKey.DoesNotExist:
            return Response(
                {'error': 'Invalid or inactive master key'},
                status=status.HTTP_401_UNAUTHORIZED,
            )

        master_key.last_used = timezone.now()
        master_key.save(update_fields=['last_used'])

        # A missing, null or empty username falls back to a service user of
        # this key. Using ``.get('username', default)`` alone let an explicit
        # null through to get_or_create(username=None), and let an empty string
        # create a user named "".
        user_identifier = request.data.get('username') or f'service_user_{master_key.key_id}'

        # NOTE(review): get_or_create also returns *existing* accounts, so any
        # holder of an active master key can obtain tokens for any user,
        # including staff/superusers. Confirm this is intended.
        user, _created = User.objects.get_or_create(
            username=user_identifier,
            defaults={'is_active': True},
        )

        # Existing accounts may have been deactivated; don't mint tokens for them.
        if not user.is_active:
            return Response(
                {'error': 'User account is disabled'},
                status=status.HTTP_403_FORBIDDEN,
            )

        refresh = RefreshToken.for_user(user)
        # Custom claims identifying the issuing key.
        # NOTE(review): presumably ``permissions`` is JSON-serializable. Confirm
        # against the MasterKey model.
        refresh['key_id'] = str(master_key.key_id)
        refresh['permissions'] = master_key.permissions

        return Response({
            'refresh': str(refresh),
            'access': str(refresh.access_token),
        })