mas_handlers/activity_tracker/
worker.rs1use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{
11 BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository,
12};
13use opentelemetry::{
14 Key, KeyValue,
15 metrics::{Counter, Gauge, Histogram},
16};
17use tokio_util::sync::CancellationToken;
18use ulid::Ulid;
19
20use crate::{
21 METER,
22 activity_tracker::{Message, SessionKind},
23};
24
/// Maximum number of distinct session records buffered in memory between
/// flushes.
///
/// When this limit is reached the worker flushes early, and drops incoming
/// records if that flush does not free up space.
const MAX_PENDING_RECORDS: usize = 1000;
31
32const TYPE: Key = Key::from_static_str("type");
33const SESSION_KIND: Key = Key::from_static_str("session_kind");
34const RESULT: Key = Key::from_static_str("result");
35
36#[derive(Clone, Copy, Debug)]
37struct ActivityRecord {
38 #[allow(dead_code)]
40 start_time: DateTime<Utc>,
41 end_time: DateTime<Utc>,
42 ip: Option<IpAddr>,
43}
44
45pub struct Worker {
47 repository_factory: BoxRepositoryFactory,
48 pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
49 pending_records_gauge: Gauge<u64>,
50 message_counter: Counter<u64>,
51 flush_time_histogram: Histogram<u64>,
52}
53
54impl Worker {
55 pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self {
56 let message_counter = METER
57 .u64_counter("mas.activity_tracker.messages")
58 .with_description("The number of messages received by the activity tracker")
59 .with_unit("{messages}")
60 .build();
61
62 for kind in &[
64 SessionKind::OAuth2,
65 SessionKind::Compat,
66 SessionKind::Browser,
67 ] {
68 message_counter.add(
69 0,
70 &[
71 KeyValue::new(TYPE, "record"),
72 KeyValue::new(SESSION_KIND, kind.as_str()),
73 ],
74 );
75 }
76 message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
77 message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
78
79 let flush_time_histogram = METER
80 .u64_histogram("mas.activity_tracker.flush_time")
81 .with_description("The time it took to flush the activity tracker")
82 .with_unit("ms")
83 .build();
84
85 let pending_records_gauge = METER
86 .u64_gauge("mas.activity_tracker.pending_records")
87 .with_description("The number of pending activity records")
88 .with_unit("{records}")
89 .build();
90 pending_records_gauge.record(0, &[]);
91
92 Self {
93 repository_factory,
94 pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
95 pending_records_gauge,
96 message_counter,
97 flush_time_histogram,
98 }
99 }
100
101 pub(super) async fn run(
102 mut self,
103 mut receiver: tokio::sync::mpsc::Receiver<Message>,
104 cancellation_token: CancellationToken,
105 ) {
106 let _guard = cancellation_token.clone().drop_guard();
109
110 loop {
111 let message = tokio::select! {
112 () = cancellation_token.cancelled(), if !receiver.is_closed() => {
115 receiver.close();
118 tracing::debug!("Shutting down activity tracker");
119 continue;
120 },
121
122 message = receiver.recv() => {
123 let Some(message) = message else { break };
125 message
126 }
127 };
128
129 match message {
130 Message::Record {
131 kind,
132 id,
133 date_time,
134 ip,
135 } => {
136 if self.pending_records.len() >= MAX_PENDING_RECORDS {
137 tracing::warn!("Too many pending activity records, flushing");
138 self.flush().await;
139 }
140
141 if self.pending_records.len() >= MAX_PENDING_RECORDS {
142 tracing::error!(
143 kind = kind.as_str(),
144 %id,
145 %date_time,
146 "Still too many pending activity records, dropping"
147 );
148 continue;
149 }
150
151 self.message_counter.add(
152 1,
153 &[
154 KeyValue::new(TYPE, "record"),
155 KeyValue::new(SESSION_KIND, kind.as_str()),
156 ],
157 );
158
159 let record =
160 self.pending_records
161 .entry((kind, id))
162 .or_insert_with(|| ActivityRecord {
163 start_time: date_time,
164 end_time: date_time,
165 ip,
166 });
167
168 record.end_time = date_time.max(record.end_time);
169 }
170
171 Message::Flush(tx) => {
172 self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
173
174 self.flush().await;
175 let _ = tx.send(());
176 }
177 }
178
179 self.pending_records_gauge
181 .record(self.pending_records.len() as u64, &[]);
182 }
183
184 self.flush().await;
186 }
187
188 async fn flush(&mut self) {
190 if self.pending_records.is_empty() {
192 return;
193 }
194
195 let start = std::time::Instant::now();
196 let res = self.try_flush().await;
197
198 let duration = start.elapsed();
200 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
201
202 match res {
203 Ok(()) => {
204 self.flush_time_histogram
205 .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
206 }
207 Err(e) => {
208 self.flush_time_histogram
209 .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
210 tracing::error!(
211 error = &e as &dyn std::error::Error,
212 "Failed to flush activity tracker"
213 );
214 }
215 }
216 }
217
218 #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
220 async fn try_flush(&mut self) -> Result<(), RepositoryError> {
221 let pending_records = &self.pending_records;
222 let mut repo = self.repository_factory.create().await?;
223
224 let mut browser_sessions = Vec::new();
225 let mut oauth2_sessions = Vec::new();
226 let mut compat_sessions = Vec::new();
227
228 for ((kind, id), record) in pending_records {
229 match kind {
230 SessionKind::Browser => {
231 browser_sessions.push((*id, record.end_time, record.ip));
232 }
233 SessionKind::OAuth2 => {
234 oauth2_sessions.push((*id, record.end_time, record.ip));
235 }
236 SessionKind::Compat => {
237 compat_sessions.push((*id, record.end_time, record.ip));
238 }
239 }
240 }
241
242 tracing::info!(
243 "Flushing {} activity records to the database",
244 pending_records.len()
245 );
246
247 repo.browser_session()
248 .record_batch_activity(browser_sessions)
249 .await?;
250 repo.oauth2_session()
251 .record_batch_activity(oauth2_sessions)
252 .await?;
253 repo.compat_session()
254 .record_batch_activity(compat_sessions)
255 .await?;
256
257 repo.save().await?;
258 self.pending_records.clear();
259
260 Ok(())
261 }
262}