mas_handlers/activity_tracker/
worker.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{
11    BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository,
12};
13use opentelemetry::{
14    Key, KeyValue,
15    metrics::{Counter, Gauge, Histogram},
16};
17use tokio_util::sync::CancellationToken;
18use ulid::Ulid;
19
20use crate::{
21    METER,
22    activity_tracker::{Message, SessionKind},
23};
24
/// The maximum number of pending activity records before we flush them to the
/// database automatically.
///
/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less
/// than 100 bytes, so this should allocate around 100kB of memory.
///
/// This is a `const` rather than a `static`: it is only ever used as a value
/// (comparisons, map capacity), never through a reference to a fixed address.
const MAX_PENDING_RECORDS: usize = 1000;
31
32const TYPE: Key = Key::from_static_str("type");
33const SESSION_KIND: Key = Key::from_static_str("session_kind");
34const RESULT: Key = Key::from_static_str("result");
35
36#[derive(Clone, Copy, Debug)]
37struct ActivityRecord {
38    // XXX: We don't actually use the start time for now
39    #[allow(dead_code)]
40    start_time: DateTime<Utc>,
41    end_time: DateTime<Utc>,
42    ip: Option<IpAddr>,
43}
44
45/// Handles writing activity records to the database.
46pub struct Worker {
47    repository_factory: BoxRepositoryFactory,
48    pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
49    pending_records_gauge: Gauge<u64>,
50    message_counter: Counter<u64>,
51    flush_time_histogram: Histogram<u64>,
52}
53
54impl Worker {
55    pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self {
56        let message_counter = METER
57            .u64_counter("mas.activity_tracker.messages")
58            .with_description("The number of messages received by the activity tracker")
59            .with_unit("{messages}")
60            .build();
61
62        // Record stuff on the counter so that the metrics are initialized
63        for kind in &[
64            SessionKind::OAuth2,
65            SessionKind::Compat,
66            SessionKind::Browser,
67        ] {
68            message_counter.add(
69                0,
70                &[
71                    KeyValue::new(TYPE, "record"),
72                    KeyValue::new(SESSION_KIND, kind.as_str()),
73                ],
74            );
75        }
76        message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
77        message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
78
79        let flush_time_histogram = METER
80            .u64_histogram("mas.activity_tracker.flush_time")
81            .with_description("The time it took to flush the activity tracker")
82            .with_unit("ms")
83            .build();
84
85        let pending_records_gauge = METER
86            .u64_gauge("mas.activity_tracker.pending_records")
87            .with_description("The number of pending activity records")
88            .with_unit("{records}")
89            .build();
90        pending_records_gauge.record(0, &[]);
91
92        Self {
93            repository_factory,
94            pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
95            pending_records_gauge,
96            message_counter,
97            flush_time_histogram,
98        }
99    }
100
101    pub(super) async fn run(
102        mut self,
103        mut receiver: tokio::sync::mpsc::Receiver<Message>,
104        cancellation_token: CancellationToken,
105    ) {
106        // This guard on the shutdown token is to ensure that if this task crashes for
107        // any reason, the server will shut down
108        let _guard = cancellation_token.clone().drop_guard();
109
110        loop {
111            let message = tokio::select! {
112                // Because we want the cancellation token to trigger only once,
113                // we looked whether we closed the channel or not
114                () = cancellation_token.cancelled(), if !receiver.is_closed() => {
115                    // We only close the channel, which will make it flush all
116                    // the pending messages
117                    receiver.close();
118                    tracing::debug!("Shutting down activity tracker");
119                    continue;
120                },
121
122                message = receiver.recv()  => {
123                    // We consumed all the messages, break out of the loop
124                    let Some(message) = message else { break };
125                    message
126                }
127            };
128
129            match message {
130                Message::Record {
131                    kind,
132                    id,
133                    date_time,
134                    ip,
135                } => {
136                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
137                        tracing::warn!("Too many pending activity records, flushing");
138                        self.flush().await;
139                    }
140
141                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
142                        tracing::error!(
143                            kind = kind.as_str(),
144                            %id,
145                            %date_time,
146                            "Still too many pending activity records, dropping"
147                        );
148                        continue;
149                    }
150
151                    self.message_counter.add(
152                        1,
153                        &[
154                            KeyValue::new(TYPE, "record"),
155                            KeyValue::new(SESSION_KIND, kind.as_str()),
156                        ],
157                    );
158
159                    let record =
160                        self.pending_records
161                            .entry((kind, id))
162                            .or_insert_with(|| ActivityRecord {
163                                start_time: date_time,
164                                end_time: date_time,
165                                ip,
166                            });
167
168                    record.end_time = date_time.max(record.end_time);
169                }
170
171                Message::Flush(tx) => {
172                    self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
173
174                    self.flush().await;
175                    let _ = tx.send(());
176                }
177            }
178
179            // Update the gauge
180            self.pending_records_gauge
181                .record(self.pending_records.len() as u64, &[]);
182        }
183
184        // Flush one last time
185        self.flush().await;
186    }
187
188    /// Flush the activity tracker.
189    async fn flush(&mut self) {
190        // Short path: if there are no pending records, we don't need to flush
191        if self.pending_records.is_empty() {
192            return;
193        }
194
195        let start = std::time::Instant::now();
196        let res = self.try_flush().await;
197
198        // Measure the time it took to flush the activity tracker
199        let duration = start.elapsed();
200        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
201
202        match res {
203            Ok(()) => {
204                self.flush_time_histogram
205                    .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
206            }
207            Err(e) => {
208                self.flush_time_histogram
209                    .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
210                tracing::error!(
211                    error = &e as &dyn std::error::Error,
212                    "Failed to flush activity tracker"
213                );
214            }
215        }
216    }
217
218    /// Fallible part of [`Self::flush`].
219    #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
220    async fn try_flush(&mut self) -> Result<(), RepositoryError> {
221        let pending_records = &self.pending_records;
222        let mut repo = self.repository_factory.create().await?;
223
224        let mut browser_sessions = Vec::new();
225        let mut oauth2_sessions = Vec::new();
226        let mut compat_sessions = Vec::new();
227
228        for ((kind, id), record) in pending_records {
229            match kind {
230                SessionKind::Browser => {
231                    browser_sessions.push((*id, record.end_time, record.ip));
232                }
233                SessionKind::OAuth2 => {
234                    oauth2_sessions.push((*id, record.end_time, record.ip));
235                }
236                SessionKind::Compat => {
237                    compat_sessions.push((*id, record.end_time, record.ip));
238                }
239            }
240        }
241
242        tracing::info!(
243            "Flushing {} activity records to the database",
244            pending_records.len()
245        );
246
247        repo.browser_session()
248            .record_batch_activity(browser_sessions)
249            .await?;
250        repo.oauth2_session()
251            .record_batch_activity(oauth2_sessions)
252            .await?;
253        repo.compat_session()
254            .record_batch_activity(compat_sessions)
255            .await?;
256
257        repo.save().await?;
258        self.pending_records.clear();
259
260        Ok(())
261    }
262}