Skip to content

Commit 5783654

Browse files
Chat Sharing (#200)
* share table * index and deletions shared conversation * global update * global sharing * get by user global share * adding project sharing settings * user management and project and tags * get global share * PR comments * filter conversation tags by user * PR comments * PR comments
1 parent c996e22 commit 5783654

File tree

8 files changed

+356
-9
lines changed

8 files changed

+356
-9
lines changed

business_objects/user.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22
from . import general, organization, team_member
3-
from .. import User, enums
3+
from .. import User, enums, Team, TeamMember, TeamResource
44
from ..session import session
55
from typing import List, Optional
66
from sqlalchemy import sql
@@ -52,6 +52,19 @@ def get_all(
5252
return query.all()
5353

5454

55+
def get_all_team_members_by_project(project_id: str) -> List[User]:
56+
query = (
57+
session.query(TeamMember)
58+
.join(Team, Team.id == TeamMember.team_id)
59+
.join(TeamResource, TeamResource.team_id == Team.id)
60+
.filter(TeamResource.resource_id == project_id)
61+
.filter(
62+
TeamResource.resource_type == enums.TeamResourceType.COGNITION_PROJECT.value
63+
)
64+
)
65+
return query.all()
66+
67+
5568
def get_count_assigned() -> int:
5669
return session.query(User.id).filter(User.organization_id != None).count()
5770

cognition_objects/conversation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def get(project_id: str, conversation_id: str) -> CognitionConversation:
2929
)
3030

3131

32+
def get_by_id(conversation_id: str) -> CognitionConversation:
33+
return (
34+
session.query(CognitionConversation)
35+
.filter(CognitionConversation.id == conversation_id)
36+
.first()
37+
)
38+
39+
3240
def exists(project_id: str, conversation_id: str) -> bool:
3341
return (
3442
session.query(CognitionConversation)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from operator import or_
2+
from typing import List, Optional, Dict, Any
3+
from ..business_objects import general
4+
from ..session import session
5+
from ..models import CognitionConversation, ConversationGlobalShare
6+
from submodules.model.util import sql_alchemy_to_dict
7+
8+
9+
def get(conversation_global_share_id: str) -> Optional[ConversationGlobalShare]:
10+
return (
11+
session.query(ConversationGlobalShare)
12+
.filter(ConversationGlobalShare.id == conversation_global_share_id)
13+
.first()
14+
)
15+
16+
17+
def get_by_conversation(conversation_id: str) -> Optional[ConversationGlobalShare]:
18+
return (
19+
session.query(ConversationGlobalShare)
20+
.filter(ConversationGlobalShare.conversation_id == conversation_id)
21+
.first()
22+
)
23+
24+
25+
def create(
26+
conversation_id: str,
27+
shared_by: str,
28+
with_commit: bool = True,
29+
) -> ConversationGlobalShare:
30+
global_share = ConversationGlobalShare(
31+
conversation_id=conversation_id, shared_by=shared_by
32+
)
33+
general.add(global_share, with_commit)
34+
return global_share
35+
36+
37+
def delete_by_conversation(
38+
conversation_id: str, user_id: str, with_commit: bool = True
39+
):
40+
(
41+
session.query(ConversationGlobalShare)
42+
.filter(
43+
ConversationGlobalShare.conversation_id == conversation_id,
44+
ConversationGlobalShare.shared_by == user_id,
45+
)
46+
.delete()
47+
)
48+
general.flush_or_commit(with_commit)
49+
50+
51+
def get_by_user(project_id: str, user_id: str) -> List[Dict[str, Any]]:
52+
conversation_global_shares = (
53+
session.query(ConversationGlobalShare, CognitionConversation.header)
54+
.join(
55+
CognitionConversation,
56+
ConversationGlobalShare.conversation_id == CognitionConversation.id,
57+
)
58+
.filter(ConversationGlobalShare.shared_by == user_id)
59+
.filter(CognitionConversation.project_id == project_id)
60+
.all()
61+
)
62+
conversation_global_shares_dict = []
63+
for share_obj, header in conversation_global_shares:
64+
share_dict = sql_alchemy_to_dict(share_obj)
65+
share_dict["conversation_header"] = header
66+
conversation_global_shares_dict.append(share_dict)
67+
return conversation_global_shares_dict
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from operator import or_
2+
from typing import List, Optional
3+
from ..business_objects import general
4+
from ..session import session
5+
from ..models import ConversationShare, CognitionConversation
6+
from submodules.model.util import sql_alchemy_to_dict
7+
8+
9+
def get(share_id: str, user_id: str) -> Optional[ConversationShare]:
10+
return (
11+
session.query(ConversationShare)
12+
.filter(ConversationShare.id == share_id)
13+
.filter(
14+
or_(
15+
ConversationShare.shared_by == user_id,
16+
ConversationShare.shared_with == user_id,
17+
)
18+
)
19+
.first()
20+
)
21+
22+
23+
def get_all_by_conversation(conversation_id: str) -> List[ConversationShare]:
24+
return (
25+
session.query(ConversationShare)
26+
.filter(ConversationShare.conversation_id == conversation_id)
27+
.all()
28+
)
29+
30+
31+
def update_by_conversation(
32+
conversation_id: str,
33+
user_id: str,
34+
shared_with: List[str],
35+
can_copy: Optional[bool] = None,
36+
with_commit: bool = True,
37+
) -> List[ConversationShare]:
38+
existing_shares = (
39+
session.query(ConversationShare)
40+
.filter(ConversationShare.conversation_id == conversation_id)
41+
.filter(ConversationShare.shared_by == user_id)
42+
.all()
43+
)
44+
45+
existing_shared_with = {share.shared_with: share for share in existing_shares}
46+
shared_with_set = set(shared_with)
47+
48+
for share in existing_shares:
49+
if share.shared_with not in shared_with_set:
50+
general.delete(share, with_commit=False)
51+
52+
for sharing_user_id in shared_with:
53+
if sharing_user_id not in existing_shared_with:
54+
share = ConversationShare(
55+
conversation_id=conversation_id,
56+
shared_with=sharing_user_id,
57+
shared_by=user_id,
58+
can_copy=can_copy if can_copy is not None else False,
59+
)
60+
general.add(share, with_commit=False)
61+
else:
62+
if can_copy is not None:
63+
existing_shared_with[sharing_user_id].can_copy = can_copy
64+
65+
general.flush_or_commit(with_commit)
66+
67+
updated_shares = (
68+
session.query(ConversationShare)
69+
.filter(ConversationShare.conversation_id == conversation_id)
70+
.filter(ConversationShare.shared_by == user_id)
71+
.all()
72+
)
73+
return updated_shares
74+
75+
76+
def get_all_shared_by_or_for_user(
77+
project_id: str, user_id: str, with_header: bool = True
78+
) -> List[ConversationShare]:
79+
conversation_shares = (
80+
session.query(ConversationShare, CognitionConversation.header)
81+
.join(
82+
CognitionConversation,
83+
ConversationShare.conversation_id == CognitionConversation.id,
84+
)
85+
.filter(CognitionConversation.project_id == project_id)
86+
.filter(
87+
or_(
88+
ConversationShare.shared_by == user_id,
89+
ConversationShare.shared_with == user_id,
90+
)
91+
)
92+
.all()
93+
)
94+
conversation_shares_dict = []
95+
for share_obj, header in conversation_shares:
96+
share_dict = sql_alchemy_to_dict(share_obj)
97+
share_dict["conversation_header"] = header
98+
conversation_shares_dict.append(share_dict)
99+
100+
return conversation_shares_dict
101+
102+
103+
def get_all_shared_by_user(user_id: str) -> List[ConversationShare]:
104+
return (
105+
session.query(ConversationShare)
106+
.filter(ConversationShare.shared_by == user_id)
107+
.all()
108+
)
109+
110+
111+
def create(
112+
conversation_id: str,
113+
shared_with: str,
114+
shared_by: str,
115+
can_copy: bool = False,
116+
with_commit: bool = True,
117+
) -> ConversationShare:
118+
share = ConversationShare(
119+
conversation_id=conversation_id,
120+
shared_with=shared_with,
121+
shared_by=shared_by,
122+
can_copy=can_copy,
123+
)
124+
general.add(share, with_commit)
125+
return share
126+
127+
128+
def create_many(
129+
conversation_id: str,
130+
shared_with_user_ids: List[str],
131+
shared_by: str,
132+
can_copy: bool = False,
133+
with_commit: bool = True,
134+
) -> List[ConversationShare]:
135+
136+
shares = [
137+
ConversationShare(
138+
conversation_id=conversation_id,
139+
shared_with=user_id,
140+
shared_by=shared_by,
141+
can_copy=can_copy,
142+
)
143+
for user_id in shared_with_user_ids
144+
]
145+
general.add_all(shares, with_commit=True)
146+
return shares
147+
148+
149+
def update(
150+
share_id: str,
151+
user_id: str,
152+
can_copy: Optional[bool] = None,
153+
with_commit: bool = True,
154+
) -> Optional[ConversationShare]:
155+
share_entity = get(share_id)
156+
if share_entity is None:
157+
return None
158+
159+
if str(share_entity.shared_by) != user_id:
160+
raise ValueError("You are not allowed to update this sharing context.")
161+
162+
if can_copy is not None:
163+
share_entity.can_copy = can_copy
164+
165+
general.flush_or_commit(with_commit)
166+
return share_entity
167+
168+
169+
def delete_shared_with(
170+
conversation_share_id: str, user_id: str, with_commit: bool = True
171+
) -> None:
172+
session.query(ConversationShare).filter(
173+
ConversationShare.id == conversation_share_id
174+
).filter(ConversationShare.shared_with == user_id).delete()
175+
general.flush_or_commit(with_commit)
176+
177+
178+
def delete_shared_by_by_conversation_id(
179+
conversation_id: str, user_id: str, with_commit: bool = True
180+
) -> None:
181+
session.query(ConversationShare).filter(
182+
ConversationShare.conversation_id == conversation_id
183+
).filter(ConversationShare.shared_by == user_id).delete()
184+
general.flush_or_commit(with_commit)

cognition_objects/conversation_tags.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,32 @@ def delete_association(
128128
general.flush_or_commit(with_commit)
129129

130130

131+
def delete_associations_by_conversation(
132+
conversation_id: str,
133+
with_commit: bool = True,
134+
) -> None:
135+
136+
session.query(CognitionConversationTagAssociation).filter(
137+
CognitionConversationTagAssociation.conversation_id == conversation_id,
138+
).delete(synchronize_session=False)
139+
140+
general.flush_or_commit(with_commit)
141+
142+
131143
def get_lookup_by_conversation_ids(
132-
conversation_ids: List[str],
144+
conversation_ids: List[str], user_id: Optional[str] = None
133145
) -> Dict[str, List[Dict[str, Any]]]:
134-
associations = (
135-
session.query(CognitionConversationTagAssociation)
136-
.filter(
137-
CognitionConversationTagAssociation.conversation_id.in_(conversation_ids)
138-
)
139-
.all()
146+
query = session.query(CognitionConversationTagAssociation)
147+
if user_id is not None:
148+
query = query.join(
149+
CognitionConversationTag,
150+
CognitionConversationTag.id == CognitionConversationTagAssociation.tag_id,
151+
).filter(CognitionConversationTag.created_by == user_id)
152+
query = query.filter(
153+
CognitionConversationTagAssociation.conversation_id.in_(conversation_ids)
140154
)
155+
associations = query.all()
156+
141157
tag_lookup: Dict[str, List[Dict[str, Any]]] = {}
142158

143159
for association in associations:

cognition_objects/project.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_org_id(project_id: str) -> str:
3535
raise ValueError(f"Project with id {project_id} not found")
3636

3737

38-
def get_by_user(project_id: str, user_id: str) -> CognitionProject:
38+
def get_by_user(project_id: str, user_id: str) -> List[Dict[str, Any]]:
3939
user_item = user.get(user_id)
4040
if user_item.role == enums.UserRoles.ENGINEER.value:
4141
return get(project_id)
@@ -215,6 +215,8 @@ def update(
215215
llm_config: Optional[Dict[str, Any]] = None,
216216
tokenizer: Optional[str] = None,
217217
icon: Optional[str] = None,
218+
allow_conversation_sharing_organization: Optional[bool] = None,
219+
allow_conversation_sharing_global: Optional[bool] = None,
218220
with_commit: bool = True,
219221
) -> CognitionProject:
220222
project: CognitionProject = get(project_id)
@@ -288,6 +290,12 @@ def update(
288290
project.tokenizer = tokenizer
289291
if icon is not None:
290292
project.icon = icon
293+
if allow_conversation_sharing_organization is not None:
294+
project.allow_conversation_sharing_organization = (
295+
allow_conversation_sharing_organization
296+
)
297+
if allow_conversation_sharing_global is not None:
298+
project.allow_conversation_sharing_global = allow_conversation_sharing_global
291299
general.flush_or_commit(with_commit)
292300
return project
293301

enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class Tablenames(Enum):
178178
ADMIN_QUERY_MESSAGE_SUMMARY = "admin_query_message_summary"
179179
RELEASE_NOTIFICATION = "release_notification"
180180
TIMED_EXECUTIONS = "timed_executions"
181+
CONVERSATION_SHARE = "conversation_share"
182+
CONVERSATION_GLOBAL_SHARE = "conversation_global_share"
181183

182184
def snake_case_to_pascal_case(self):
183185
# the type name (written in PascalCase) of a table is needed to create backrefs

0 commit comments

Comments
 (0)