the-algorithm/tweetypie/common/src/scala/com/twitter/tweetypie/util/TweetCreationLock.scala

204 lines
7.3 KiB
Scala

package com.twitter.tweetypie.util
import com.twitter.conversions.DurationOps._
import com.twitter.logging.Logger
import com.twitter.mediaservices.commons.mediainformation.thriftscala.UserDefinedProductMetadata
import com.twitter.scrooge.BinaryThriftStructSerializer
import com.twitter.servo.cache.ScopedCacheKey
import com.twitter.servo.util.Transformer
import com.twitter.tweetypie.thriftscala.PostTweetRequest
import com.twitter.util.Base64Long
import com.twitter.util.Time
import java.nio.ByteBuffer
import java.security.MessageDigest
import org.apache.commons.codec.binary.Base64
import scala.collection.immutable.SortedMap
object TweetCreationLock {
case class Key private (userId: UserId, typeCode: String, idOrMd5: String)
extends ScopedCacheKey("t", "locker", 2, Base64Long.toBase64(userId), typeCode, idOrMd5) {
def uniquenessId: Option[String] =
if (typeCode == Key.TypeCode.UniquenessId) Some(idOrMd5) else None
}
object Key {
private[this] val log = Logger(getClass)
object TypeCode {
val SourceTweetId = "r"
val UniquenessId = "u"
val PostTweetRequest = "p"
}
private[this] val serializer = BinaryThriftStructSerializer(PostTweetRequest)
// normalize the representation of no media ids.
private[util] def sanitizeMediaUploadIds(mediaUploadIds: Option[Seq[Long]]) =
mediaUploadIds.filter(_.nonEmpty)
/**
* Request deduplication depends on the hash of a serialized Thrift value.
*
* In order to guarantee that a Map has a reproducible serialized form,
* it's necessary to fix the ordering of its keys.
*/
private[util] def sanitizeMediaMetadata(
mediaMetadata: Option[scala.collection.Map[MediaId, UserDefinedProductMetadata]]
): Option[scala.collection.Map[MediaId, UserDefinedProductMetadata]] =
mediaMetadata.map(m => SortedMap(m.toSeq: _*))
/**
* Make sure to sanitize request fields with map/set since serialized
* bytes ordering is not guaranteed for same thrift values.
*/
private[util] def sanitizeRequest(request: PostTweetRequest): PostTweetRequest =
PostTweetRequest(
userId = request.userId,
text = request.text,
createdVia = "",
inReplyToTweetId = request.inReplyToTweetId,
geo = request.geo,
mediaUploadIds = sanitizeMediaUploadIds(request.mediaUploadIds),
narrowcast = request.narrowcast,
nullcast = request.nullcast,
additionalFields = request.additionalFields,
attachmentUrl = request.attachmentUrl,
mediaMetadata = sanitizeMediaMetadata(request.mediaMetadata),
conversationControl = request.conversationControl,
underlyingCreativesContainerId = request.underlyingCreativesContainerId,
editOptions = request.editOptions,
noteTweetOptions = request.noteTweetOptions
)
def bySourceTweetId(userId: UserId, sourceTweetId: TweetId): Key =
Key(userId, TypeCode.SourceTweetId, Base64Long.toBase64(sourceTweetId))
def byRequest(request: PostTweetRequest): Key =
request.uniquenessId match {
case Some(uqid) =>
byUniquenessId(request.userId, uqid)
case None =>
val sanitized = sanitizeRequest(request)
val sanitizedBytes = serializer.toBytes(sanitized)
val digested = MessageDigest.getInstance("SHA-256").digest(sanitizedBytes)
val base64Digest = Base64.encodeBase64String(digested)
val key = Key(request.userId, TypeCode.PostTweetRequest, base64Digest)
log.ifDebug(s"Generated key $key from request:\n${sanitized}")
key
}
/**
* Key for tweets that have a uniqueness id set. There is only one
* namespace of uniqueness ids, across all clients. They are
* expected to be Snowflake ids, in order to avoid cache
* collisions.
*/
def byUniquenessId(userId: UserId, uniquenessId: Long): Key =
Key(userId, TypeCode.UniquenessId, Base64Long.toBase64(uniquenessId))
}
/**
* The state of tweet creation for a given Key (request).
*/
sealed trait State
object State {
/**
* There is no tweet creation currently in progress. (This can
* either be represented by no entry in the cache, or this special
* marker. This lets us use checkAndSet for deletion to avoid
* accidentally overwriting other process' values.)
*/
case object Unlocked extends State
/**
* Some process is attempting to create the tweet.
*/
case class InProgress(token: Long, timestamp: Time) extends State
/**
* The tweet has already been successfully created, and has the
* specified id.
*/
case class AlreadyCreated(tweetId: TweetId, timestamp: Time) extends State
/**
* When stored in cache, each state is prefixed by a byte
* indicating the type of the entry.
*/
object TypeCode {
val Unlocked: Byte = 0.toByte
val InProgress: Byte = 1.toByte // + random long + timestamp
val AlreadyCreated: Byte = 2.toByte // + tweet id + timestamp
}
private[this] val BufferSize = 17 // type byte + 64-bit value + 64-bit timestamp
// Constant buffer to use for storing the serialized form on
// Unlocked.
private[this] val UnlockedBuf = Array[Byte](TypeCode.Unlocked)
// Store the serialization function in a ThreadLocal so that we can
// reuse the buffer between invocations.
private[this] val threadLocalSerialize = new ThreadLocal[State => Array[Byte]] {
override def initialValue(): State => Array[Byte] = {
// Allocate the thread-local state
val ary = new Array[Byte](BufferSize)
val buf = ByteBuffer.wrap(ary)
{
case Unlocked => UnlockedBuf
case InProgress(token, timestamp) =>
buf.clear()
buf
.put(TypeCode.InProgress)
.putLong(token)
.putLong(timestamp.sinceEpoch.inNanoseconds)
ary
case AlreadyCreated(tweetId, timestamp) =>
buf.clear()
buf
.put(TypeCode.AlreadyCreated)
.putLong(tweetId)
.putLong(timestamp.sinceEpoch.inNanoseconds)
ary
}
}
}
/**
* Convert this State to the cache representation.
*/
private[this] def toBytes(state: State): Array[Byte] =
threadLocalSerialize.get()(state)
/**
* Convert this byte array into a LockState.
*
* @throws RuntimeException if the buffer is not of the right size
* and format
*/
private[this] def fromBytes(bytes: Array[Byte]): State = {
val buf = ByteBuffer.wrap(bytes)
val result = buf.get() match {
case TypeCode.Unlocked => Unlocked
case TypeCode.InProgress => InProgress(buf.getLong(), buf.getLong().nanoseconds.afterEpoch)
case TypeCode.AlreadyCreated =>
AlreadyCreated(buf.getLong(), buf.getLong().nanoseconds.afterEpoch)
case other => throw new RuntimeException("Invalid type code: " + other)
}
if (buf.remaining != 0) {
throw new RuntimeException("Extra data in buffer: " + bytes)
}
result
}
/**
* How to serialize the State for storage in cache.
*/
val Serializer: Transformer[State, Array[Byte]] =
Transformer[State, Array[Byte]](tTo = toBytes _, tFrom = fromBytes _)
}
}