diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..5ca0973f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.DS_Store + diff --git a/COPYING b/COPYING new file mode 100644 index 000000000..be3f7b28e --- /dev/null +++ b/COPYING @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. 
+ + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. 
+ + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. 
This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. 
Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. 
+Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). 
+ + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". 
+ + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. 
+ + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. 
+EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+  16. Limitation of Liability.
+
+  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+  17. Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU Affero General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU Affero General Public License for more details.
+
+    You should have received a copy of the GNU Affero General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/IMPORTANT.txt b/IMPORTANT.txt
new file mode 100644
index 000000000..f98c882f4
--- /dev/null
+++ b/IMPORTANT.txt
@@ -0,0 +1,15 @@
+⢀⡴⠑⡄⠀⠀⠀⠀⠀⠀⠀⣀⣀⣤⣤⣤⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+⠸⡇⠀⠿⡀⠀⠀⠀⣀⡴⢿⣿⣿⣿⣿⣿⣿⣿⣷⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠑⢄⣠⠾⠁⣀⣄⡈⠙⣿⣿⣿⣿⣿⣿⣿⣿⣆⠀⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⢀⡀⠁⠀⠀⠈⠙⠛⠂⠈⣿⣿⣿⣿⣿⠿⡿⢿⣆⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⢀⡾⣁⣀⠀⠴⠂⠙⣗⡀⠀⢻⣿⣿⠭⢤⣴⣦⣤⣹⠀⠀⠀⢀⢴⣶⣆
+⠀⠀⢀⣾⣿⣿⣿⣷⣮⣽⣾⣿⣥⣴⣿⣿⡿⢂⠔⢚⡿⢿⣿⣦⣴⣾⠁⠸⣼⡿
+⠀⢀⡞⠁⠙⠻⠿⠟⠉⠀⠛⢹⣿⣿⣿⣿⣿⣌⢤⣼⣿⣾⣿⡟⠉⠀⠀⠀⠀⠀
+⠀⣾⣷⣶⠇⠀⠀⣤⣄⣀⡀⠈⠻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇⠀⠀⠀⠀⠀⠀
+⠀⠉⠈⠉⠀⠀⢦⡈⢻⣿⣿⣿⣶⣶⣶⣶⣤⣽⡹⣿⣿⣿⣿⡇⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⠀⠉⠲⣽⡻⢿⣿⣿⣿⣿⣿⣿⣷⣜⣿⣿⣿⡇⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣷⣶⣮⣭⣽⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⣀⣀⣈⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠇⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠃⠀⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⠀⠹⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡿⠟⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠛⠻⠿⠿⠿⠿⠛⠉
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..056cc0770
--- /dev/null
+++ b/README.md
@@ -0,0 +1,39 @@
+# Twitter Recommendation Algorithm
+
+The Twitter Recommendation Algorithm is a set of services and jobs that are responsible for constructing and serving the
+Home Timeline. For an introduction to how the algorithm works, please refer to our [engineering blog](https://blog.twitter.com/engineering/en_us/topics/open-source/2023/twitter-recommendation-algorithm). The
+diagram below illustrates how major services and jobs interconnect.
+
+![](docs/system-diagram.png)
+
+These are the main components of the Recommendation Algorithm included in this repository:
+
+| Type | Component | Description |
+|------------|------------|------------|
+| Feature | [SimClusters](src/scala/com/twitter/simclusters_v2/README.md) | Community detection and sparse embeddings into those communities. |
+| | [TwHIN](https://github.com/twitter/the-algorithm-ml/blob/main/projects/twhin/README.md) | Dense knowledge graph embeddings for Users and Tweets. |
+| | [trust-and-safety-models](trust_and_safety_models/README.md) | Models for detecting NSFW or abusive content. |
+| | [real-graph](src/scala/com/twitter/interaction_graph/README.md) | Model to predict the likelihood of a Twitter User interacting with another User. |
+| | [tweepcred](src/scala/com/twitter/graph/batch/job/tweepcred/README) | PageRank algorithm for calculating Twitter User reputation. |
+| | [recos-injector](recos-injector/README.md) | Streaming event processor for building input streams for [GraphJet](https://github.com/twitter/GraphJet)-based services. |
+| | [graph-feature-service](graph-feature-service/README.md) | Serves graph features for a directed pair of Users (e.g. how many of User A's following liked Tweets from User B). |
+| Candidate Source | [search-index](src/java/com/twitter/search/README.md) | Finds and ranks In-Network Tweets. ~50% of Tweets come from this candidate source. |
+| | [cr-mixer](cr-mixer/README.md) | Coordination layer for fetching Out-of-Network Tweet candidates from underlying compute services. |
+| | [user-tweet-entity-graph](src/scala/com/twitter/recos/user_tweet_entity_graph/README.md) (UTEG) | Maintains an in-memory User-to-Tweet interaction graph, and finds candidates based on traversals of this graph. It is built on the [GraphJet](https://github.com/twitter/GraphJet) framework. Several other GraphJet-based features and candidate sources are located [here](src/scala/com/twitter/recos). |
+| | [follow-recommendation-service](follow-recommendations-service/README.md) (FRS) | Provides Users with recommendations for accounts to follow, and Tweets from those accounts. |
+| Ranking | [light-ranker](src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md) | Light ranker model used by the search index (Earlybird) to rank Tweets. |
| +| | [heavy-ranker](https://github.com/twitter/the-algorithm-ml/blob/main/projects/home/recap/README.md) | Neural network for ranking candidate tweets. One of the main signals used to select timeline Tweets post candidate sourcing. | +| Tweet mixing & filtering | [home-mixer](home-mixer/README.md) | Main service used to construct and serve the Home Timeline. Built on [product-mixer](product-mixer/README.md) | +| | [visibility-filters](visibilitylib/README.md) | Responsible for filtering Twitter content to support legal compliance, improve product quality, increase user trust, protect revenue through the use of hard-filtering, visible product treatments, and coarse-grained downranking. | +| | [timelineranker](timelineranker/README.md) | Legacy service which provides relevance-scored tweets from the Earlybird Search Index and UTEG service. | +| Software framework | [navi](navi/navi/README.md) | High performance, machine learning model serving written in Rust. | +| | [product-mixer](product-mixer/README.md) | Software framework for building feeds of content. | +| | [twml](twml/README.md) | Legacy machine learning framework built on TensorFlow v1. | + +We include Bazel BUILD files for most components, but not a top level BUILD or WORKSPACE file. + +## Contributing + +We invite the community to submit GitHub issues and pull requests for suggestions on improving the recommendation algorithm. We are working on tools to manage these suggestions and sync changes to our internal repository. Any security concerns or issues should be routed to our official [bug bounty program](https://hackerone.com/twitter) through HackerOne. We hope to benefit from the collective intelligence and expertise of the global community in helping us identify issues and suggest improvements, ultimately leading to a better Twitter. + +Read our blog on the open source initiative [here](https://blog.twitter.com/en_us/topics/company/2023/a-new-era-of-transparency-for-twitter). 
diff --git a/twml/BUILD b/twml/BUILD new file mode 100644 index 000000000..c339f6fae --- /dev/null +++ b/twml/BUILD @@ -0,0 +1,186 @@ +twml_sources = [ + "twml/**/*.py", +] + +twml_deps = [ + "3rdparty/python/cherrypy:default", + "3rdparty/python/pyyaml:default", + "3rdparty/python/absl-py:default", + "3rdparty/python/joblib:default", + "3rdparty/python/kazoo:default", + "3rdparty/python/python-dateutil:default", + "3rdparty/python/pytz:default", + "cortex/ml-metastore/src/main/python/com/twitter/mlmetastore/modelrepo/client", + "src/python/twitter/common/app", + "src/python/twitter/common/app/modules:vars", + "src/python/twitter/common/metrics", + "src/python/twitter/deepbird/compat/v1/optimizers", + "src/python/twitter/deepbird/compat/v1/rnn", + "src/python/twitter/deepbird/hparam", + "src/python/twitter/deepbird/io", + "src/python/twitter/deepbird/io/legacy", + "src/python/twitter/deepbird/logging", + "src/python/twitter/deepbird/sparse", + "src/python/twitter/deepbird/stats_server", + "src/python/twitter/deepbird/util:simple-data-record-handler", + "src/python/twitter/deepbird/util/hashing", + "src/python/twitter/ml/api/dal", + "src/python/twitter/ml/common:metrics", + "src/python/twitter/ml/common/kubernetes", + "src/python/twitter/ml/common:resources", + "src/python/twitter/ml/twml/kubernetes", + "src/python/twitter/ml/twml:status", + "src/thrift/com/twitter/dal:dal_no_constants-python", + "src/thrift/com/twitter/statebird:compiled-v2-python", +] + +python3_library( + name = "twml-test-common-deps", + tags = ["no-mypy"], + dependencies = [ + "src/python/twitter/deepbird/util:inference", + "src/python/twitter/deepbird/util/data", + "src/thrift/com/twitter/ml/api:data-python", + "twml/tests/data:resources", + ], +) + +python3_library( + name = "twml_packer_deps_no_tf", + tags = [ + "bazel-compatible", + "no-mypy", + ], + dependencies = [ + "3rdparty/python/numpy:default", + "3rdparty/python/pandas:default", + "3rdparty/python/pyyaml:default", + "3rdparty/python/requests:default", + "3rdparty/python/scikit-learn:default", + "3rdparty/python/scipy:default", + "3rdparty/python/tensorflow-hub:default", + "3rdparty/python/thriftpy2:default", + ], +) + +python3_library( + name = "twml_packer_deps_no_tf_py3", + tags = [ + "known-to-fail-jira:CX-20246", + "no-mypy", + ], + dependencies = [ + ":twml_packer_deps_no_tf", + "3rdparty/python/tensorflow-model-analysis", + ], +) + +alias( + name = "twml-test-shared", + target = ":twml_common", +) + +python3_library( + name = "twml_common", + sources = ["twml_common/**/*.py"], + tags = [ + "bazel-compatible", + "no-mypy", + ], +) + +# Alias twml-dev to twml to avoid breaking user targets. +alias( + name = "twml-dev", + target = "twml", +) + +python3_library( + name = "twml-test-dev-deps", + tags = [ + "bazel-compatible", + "no-mypy", + ], + dependencies = [ + ":twml", + ":twml-test-common-deps", + ":twml-test-shared", + "3rdparty/python/freezegun:default", + "src/python/twitter/deepbird/keras/layers", + "src/thrift/com/twitter/ml/api:data-python", + "src/thrift/com/twitter/ml/prediction_service:prediction_service-python", + ], +) + +python3_library( + name = "twml-dev-python", + sources = twml_sources, + tags = [ + "bazel-compatible", + "no-mypy", + ], + dependencies = twml_deps + [ + ":twml_packer_deps_no_tf", + "3rdparty/python/tensorflow", + "3rdparty/python/twml:libtwml-universal", + "twml/libtwml:libtwml-python", + ], +) + +# Build a smaller .pex file that models can depend on. 
+# Tensorflow and other dependencies are downloaded from Packer on Aurora.
+# Note: This gets the C++ ops through 3rdparty artifacts.
+python3_library(
+    name = "twml-nodeps",
+    sources = twml_sources,
+    tags = [
+        "bazel-compatible",
+        "no-mypy",
+    ],
+    dependencies = twml_deps + [
+        "3rdparty/python/twml:libtwml-universal",
+    ],
+)
+
+python3_library(
+    name = "twml",
+    tags = [
+        "bazel-compatible",
+        "no-mypy",
+    ],
+    dependencies = [
+        ":twml-nodeps",
+        ":twml_packer_deps_no_tf",
+        "3rdparty/python/tensorflow",
+    ],
+)
+
+python37_binary(
+    name = "tensorboard",
+    source = "twml/tensorboard/__main__.py",
+    dependencies = [
+        "3rdparty/python/_closures/twml:tensorboard",
+        "3rdparty/python/tensorflow",
+    ],
+)
+
+python37_binary(
+    name = "saved_model_cli",
+    source = "twml/saved_model_cli/__main__.py",
+    dependencies = [
+        "3rdparty/python/_closures/twml:saved_model_cli",
+        "3rdparty/python/tensorflow",
+    ],
+)
+
+# This target is added so twml can be used regardless of the Tensorflow version:
+# This target does not pull in TensorFlow 1.x or the related libtwml compiled using TF 1.x.
+python3_library(
+    name = "twml-py-source-only",
+    sources = twml_sources,
+    tags = [
+        "known-to-fail-jira:CX-23416",
+        "no-mypy",
+    ],
+    dependencies = twml_deps,
+)
diff --git a/twml/README.md b/twml/README.md
new file mode 100644
index 000000000..df7a10328
--- /dev/null
+++ b/twml/README.md
@@ -0,0 +1,13 @@
+# TWML
+
+---
+Note: `twml` is no longer under development. Much of the code here is out of date and unused.
+It is included here for completeness, because `twml` is still used to train the light ranker models
+(see `src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md`).
+---
+
+TWML is one of Twitter's machine learning frameworks, which uses TensorFlow under the hood. While it is
+mostly deprecated, it is still currently used to train the Earlybird light ranking models
+(see `src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py`).
+The most relevant part of this is the `DataRecordTrainer` class, which is where the core training logic resides.
\ No newline at end of file
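[Editor's note] For orientation, the snippet below sketches how a model typically hooks into `DataRecordTrainer`. It is an approximation for illustration: the argument and method names (`build_graph_fn`, `get_train_input_fn`, the job name, the `params` object) are paraphrased from the twml API rather than copied from `train.py`.

```python
# Approximate twml usage, for illustration only -- not a verbatim example.
import twml


def build_graph(features, label, mode, params, config=None):
    # Define the model here and return a dict such as
    # {"output": ..., "loss": ..., "train_op": ...}.
    raise NotImplementedError


def train_light_ranker(params):
    trainer = twml.trainers.DataRecordTrainer(
        name="earlybird_light_ranker",  # hypothetical job name
        params=params,                  # parsed hyperparameters
        build_graph_fn=build_graph,
    )
    # DataRecordTrainer owns the DataRecord input pipeline and the
    # train/evaluate loop around the graph defined above.
    trainer.train(input_fn=trainer.get_train_input_fn())
    trainer.evaluate(input_fn=trainer.get_eval_input_fn())
```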
diff --git a/twml/libtwml/BUILD b/twml/libtwml/BUILD
new file mode 100644
index 000000000..c80b64b3b
--- /dev/null
+++ b/twml/libtwml/BUILD
@@ -0,0 +1,8 @@
+python3_library(
+    name = "libtwml-python",
+    sources = ["libtwml/**/*.py"],
+    tags = [
+        "no-mypy",
+        "bazel-compatible",
+    ],
+)
diff --git a/twml/libtwml/include/twml.h b/twml/libtwml/include/twml.h
new file mode 100644
index 000000000..9d88cdc7b
--- /dev/null
+++ b/twml/libtwml/include/twml.h
@@ -0,0 +1,21 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
diff --git a/twml/libtwml/include/twml/BatchPredictionRequest.h b/twml/libtwml/include/twml/BatchPredictionRequest.h
new file mode 100644
index 000000000..6070ec045
--- /dev/null
+++ b/twml/libtwml/include/twml/BatchPredictionRequest.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#ifdef __cplusplus
+
+#include <twml/DataRecord.h>
+#include <twml/HashedDataRecord.h>
+#include <type_traits>
+#include <vector>
+
+namespace twml {
+
+template <class RecordType>
+class GenericBatchPredictionRequest {
+  static_assert(std::is_same<RecordType, HashedDataRecord>::value ||
+                std::is_same<RecordType, DataRecord>::value,
+                "RecordType has to be HashedDataRecord or DataRecord");
+ public:
+  typedef typename RecordType::Reader Reader;
+  GenericBatchPredictionRequest(int numOfLabels=0, int numOfWeights=0):
+    m_common_features(), m_requests(),
+    num_labels(numOfLabels), num_weights(numOfWeights)
+  {}
+
+  void decode(Reader &reader);
+
+  std::vector<RecordType>& requests() {
+    return m_requests;
+  }
+
+  RecordType& common() {
+    return m_common_features;
+  }
+
+ private:
+  RecordType m_common_features;
+  std::vector<RecordType> m_requests;
+  int num_labels;
+  int num_weights;
+};
+
+using HashedBatchPredictionRequest = GenericBatchPredictionRequest<HashedDataRecord>;
+using BatchPredictionRequest = GenericBatchPredictionRequest<DataRecord>;
+
+}
+
+#endif
diff --git a/twml/libtwml/include/twml/BatchPredictionResponse.h b/twml/libtwml/include/twml/BatchPredictionResponse.h
new file mode 100644
index 000000000..b7e709464
--- /dev/null
+++ b/twml/libtwml/include/twml/BatchPredictionResponse.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <twml/RawTensor.h>
+#include <twml/Tensor.h>
+#include <twml/ThriftWriter.h>
+#include <cstdint>
+#include <vector>
+
+namespace twml {
+
+  // Encodes a batch of model predictions as a list of Thrift DataRecord
+  // objects inside a Thrift BatchPredictionResponse object. Prediction
+  // values are continuousFeatures inside each DataRecord.
+  //
+  // The BatchPredictionResponseWriter TensorFlow operator uses this class
+  // to determine the size of the output tensor to allocate. The operator
+  // then allocates memory for the output tensor and uses this class to
+  // write binary Thrift to the output tensor.
+  //
+  class BatchPredictionResponse {
+    private:
+      uint64_t batch_size_;
+      const Tensor &keys_;
+      const Tensor &values_; // prediction values (batch_size * num_keys)
+      const Tensor &dense_keys_;
+      const std::vector<RawTensor> &dense_values_;
+
+      inline uint64_t getBatchSize() { return batch_size_; }
+      inline bool hasContinuous() { return keys_.getNumDims() > 0; }
+      inline bool hasDenseTensors() { return dense_keys_.getNumDims() > 0; }
+
+      inline uint64_t getPredictionSize() {
+        return values_.getNumDims() > 1 ? values_.getDim(1) : 1;
+      }
+
+      void encode(twml::ThriftWriter &thrift_writer);
+
+      template <typename T>
+      void serializePredictions(twml::ThriftWriter &thrift_writer);
+
+    public:
+      // keys: 'continuousFeatures' prediction keys
+      // values: 'continuousFeatures' prediction values (batch_size * num_keys)
+      // dense_keys: 'tensors' prediction keys
+      // dense_values: 'tensors' prediction values (batch_size * num_keys)
+      BatchPredictionResponse(
+        const Tensor &keys, const Tensor &values,
+        const Tensor &dense_keys, const std::vector<RawTensor> &dense_values);
+
+      // Calculate the size of the Thrift encoded output (but do not encode).
+      // The BatchPredictionResponseWriter TensorFlow operator uses this value
+      // to allocate the output tensor.
+      uint64_t encodedSize();
+
+      // Write the BatchPredictionResponse as binary Thrift. The
+      // BatchPredictionResponseWriter operator uses this method to populate
+      // the output tensor.
+      void write(Tensor &result);
+  };
+}
diff --git a/twml/libtwml/include/twml/BlockFormatReader.h b/twml/libtwml/include/twml/BlockFormatReader.h
new file mode 100644
index 000000000..4c68458ba
--- /dev/null
+++ b/twml/libtwml/include/twml/BlockFormatReader.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace twml {
+class BlockFormatReader {
+  private:
+    int record_size_;
+    long block_pos_;
+    long block_end_;
+    char classname_[1024];
+
+    int read_one_record_size();
+    int read_int();
+    int consume_marker(int scan);
+    int unpack_varint_i32();
+    int unpack_tag_and_wiretype(uint32_t *tag, uint32_t *wiretype);
+    int unpack_string(char *out, uint64_t max_out_len);
+
+  public:
+    BlockFormatReader();
+    bool next();
+    uint64_t current_size() const { return record_size_; }
+
+    virtual uint64_t read_bytes(void *dest, int size, int count) = 0;
+};
+}
diff --git a/twml/libtwml/include/twml/BlockFormatWriter.h b/twml/libtwml/include/twml/BlockFormatWriter.h
new file mode 100644
index 000000000..b9c496f40
--- /dev/null
+++ b/twml/libtwml/include/twml/BlockFormatWriter.h
@@ -0,0 +1,61 @@
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifndef PATH_MAX
+#define PATH_MAX (8096)
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+  struct block_format_writer__;
+  typedef block_format_writer__ * block_format_writer;
+
+#ifdef __cplusplus
+}
+#endif
+
+
+#ifdef __cplusplus
+namespace twml {
+  class BlockFormatWriter {
+    private:
+      const char *file_name_;
+      FILE *outputfile_;
+      char temp_file_name_[PATH_MAX];
+      int record_index_;
+      int records_per_block_;
+
+      int pack_tag_and_wiretype(FILE *file, uint32_t tag, uint32_t wiretype);
+      int pack_varint_i32(FILE *file, int value);
+      int pack_string(FILE *file, const char *in, size_t in_len);
+      int write_int(FILE *file, int value);
+
+    public:
+      BlockFormatWriter(const char *file_name, int record_per_block);
+      ~BlockFormatWriter();
+      int write(const char *class_name, const char *record, int record_len);
+      int flush();
+      block_format_writer getHandle();
+  };
+
+  BlockFormatWriter *getBlockFormatWriter(block_format_writer w);
+} //twml namespace
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+twml_err block_format_writer_create(block_format_writer *w, const char *file_name, int records_per_block);
+twml_err block_format_write(block_format_writer w, const char *class_name, const char *record, int record_len);
+twml_err block_format_flush(block_format_writer w);
+twml_err block_format_writer_delete(const block_format_writer w);
+#ifdef __cplusplus
+}
+#endif
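[Editor's note] The size-then-write contract documented in `BatchPredictionResponse.h` above is worth pausing on: the TensorFlow operator first asks for `encodedSize()`, allocates the output tensor, and only then calls `write()`. Below is a minimal Python sketch of that two-pass pattern with a toy fixed-width encoding; it illustrates the pattern only, not the real binary Thrift encoding or the twml API.

```python
# Toy stand-in for the allocate-then-serialize contract described above.
# The encoding here (count header + packed doubles) is invented for
# illustration; BatchPredictionResponse actually emits binary Thrift.
import struct


def encoded_size(predictions: list) -> int:
    return 4 + 8 * len(predictions)  # uint32 count + one float64 per value


def write(predictions: list, buf: bytearray) -> None:
    assert len(buf) >= encoded_size(predictions), "caller must pre-allocate"
    struct.pack_into("<I", buf, 0, len(predictions))
    for i, p in enumerate(predictions):
        struct.pack_into("<d", buf, 4 + 8 * i, p)


preds = [0.12, 0.87, 0.03]
out = bytearray(encoded_size(preds))  # the operator allocates the tensor first
write(preds, out)                     # then serialization fills it in place
```

Splitting sizing from writing lets the operator allocate the output tensor exactly once, which is why the class exposes `encodedSize()` separately from `write()`.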
diff --git a/twml/libtwml/include/twml/DataRecord.h b/twml/libtwml/include/twml/DataRecord.h
new file mode 100644
index 000000000..f39f1158b
--- /dev/null
+++ b/twml/libtwml/include/twml/DataRecord.h
@@ -0,0 +1,108 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/TensorRecord.h>
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace twml {
+
+class DataRecordReader;
+
+class TWMLAPI DataRecord : public TensorRecord {
+public:
+  typedef std::vector<std::pair<int64_t, double>> SparseContinuousValueType;
+  typedef std::vector<int64_t> SparseBinaryValueType;
+  typedef Set<int64_t> BinaryFeatures;
+  typedef Map<int64_t, double> ContinuousFeatures;
+  typedef Map<int64_t, int64_t> DiscreteFeatures;
+  typedef Map<int64_t, std::string> StringFeatures;
+  typedef Map<int64_t, SparseBinaryValueType> SparseBinaryFeatures;
+  typedef Map<int64_t, SparseContinuousValueType> SparseContinuousFeatures;
+  typedef Map<int64_t, std::vector<uint8_t>> BlobFeatures;
+
+private:
+  BinaryFeatures m_binary;
+  ContinuousFeatures m_continuous;
+  DiscreteFeatures m_discrete;
+  StringFeatures m_string;
+  SparseBinaryFeatures m_sparsebinary;
+  SparseContinuousFeatures m_sparsecontinuous;
+  BlobFeatures m_blob;
+
+  std::vector<float> m_labels;
+  std::vector<float> m_weights;
+
+  void addLabel(int64_t id, double label = 1);
+  void addWeight(int64_t id, double value);
+
+public:
+  typedef DataRecordReader Reader;
+
+  DataRecord(int num_labels=0, int num_weights=0):
+    m_binary(),
+    m_continuous(),
+    m_discrete(),
+    m_string(),
+    m_sparsebinary(),
+    m_sparsecontinuous(),
+    m_blob(),
+    m_labels(num_labels, std::nanf("")),
+    m_weights(num_weights) {
+#ifdef USE_DENSE_HASH
+    m_binary.set_empty_key(0);
+    m_continuous.set_empty_key(0);
+    m_discrete.set_empty_key(0);
+    m_string.set_empty_key(0);
+    m_sparsebinary.set_empty_key(0);
+    m_sparsecontinuous.set_empty_key(0);
+#endif
+    m_binary.max_load_factor(0.5);
+    m_continuous.max_load_factor(0.5);
+    m_discrete.max_load_factor(0.5);
+    m_string.max_load_factor(0.5);
+    m_sparsebinary.max_load_factor(0.5);
+    m_sparsecontinuous.max_load_factor(0.5);
+  }
+
+  const BinaryFeatures &getBinary() const { return m_binary; }
+  const ContinuousFeatures &getContinuous() const { return m_continuous; }
+  const DiscreteFeatures &getDiscrete() const { return m_discrete; }
+  const StringFeatures &getString() const { return m_string; }
+  const SparseBinaryFeatures &getSparseBinary() const { return m_sparsebinary; }
+  const SparseContinuousFeatures &getSparseContinuous() const { return m_sparsecontinuous; }
+  const BlobFeatures &getBlob() const { return m_blob; }
+
+  const std::vector<float> &labels() const { return m_labels; }
+  const std::vector<float> &weights() const { return m_weights; }
+
+  // used by DataRecordWriter
+  template <typename T>
+  void addContinuous(std::vector<int64_t> feature_ids, std::vector<T> values) {
+    for (size_t i = 0; i < feature_ids.size(); ++i) {
+      m_continuous[feature_ids[i]] = values[i];
+    }
+  }
+
+  template <typename T>
+  void addContinuous(const int64_t *keys, uint64_t num_keys, T *values) {
+    for (size_t i = 0; i < num_keys; ++i) {
+      m_continuous[keys[i]] = values[i];
+    }
+  }
+
+  void decode(DataRecordReader &reader);
+  void clear();
+  friend class DataRecordReader;
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/DataRecordReader.h b/twml/libtwml/include/twml/DataRecordReader.h
new file mode 100644
index 000000000..0ef8e64ff
--- /dev/null
+++ b/twml/libtwml/include/twml/DataRecordReader.h
@@ -0,0 +1,61 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/DataRecord.h>
+#include <twml/TensorRecordReader.h>
+
+#include <cstdint>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace twml {
+
+class TWMLAPI DataRecordReader : public TensorRecordReader {
+
+private:
+  typedef Map<int64_t, int64_t> KeyMap_t;
+  KeyMap_t *m_keep_map;
+  KeyMap_t *m_labels_map;
+  KeyMap_t
*m_weights_map; + +public: + bool keepKey (const int64_t &key, int64_t &code); + bool isLabel (const int64_t &key, int64_t &code); + bool isWeight (const int64_t &key, int64_t &code); + void readBinary (const int feature_type , DataRecord *record); + void readContinuous (const int feature_type , DataRecord *record); + void readDiscrete (const int feature_type , DataRecord *record); + void readString (const int feature_type , DataRecord *record); + void readSparseBinary (const int feature_type , DataRecord *record); + void readSparseContinuous (const int feature_type , DataRecord *record); + void readBlob (const int feature_type , DataRecord *record); + + DataRecordReader() : + TensorRecordReader(nullptr), + m_keep_map(nullptr), + m_labels_map(nullptr), + m_weights_map(nullptr) + {} + + // Using a template instead of int64_t because tensorflow implements int64 based on compiler. + void setKeepMap(KeyMap_t *keep_map) { + m_keep_map = keep_map; + } + + void setLabelsMap(KeyMap_t *labels_map) { + m_labels_map = labels_map; + } + + void setWeightsMap(KeyMap_t *weights_map) { + m_weights_map = weights_map; + } + + void setDecodeMode(int64_t mode) {} +}; + +} +#endif diff --git a/twml/libtwml/include/twml/DataRecordWriter.h b/twml/libtwml/include/twml/DataRecordWriter.h new file mode 100644 index 000000000..6b330d323 --- /dev/null +++ b/twml/libtwml/include/twml/DataRecordWriter.h @@ -0,0 +1,39 @@ +#pragma once +#ifdef __cplusplus + +#include +#include +#include + +namespace twml { + +// Encodes DataRecords as binary Thrift. BatchPredictionResponse +// uses this class to encode prediction responses through our +// TensorFlow response writer operator. +class TWMLAPI DataRecordWriter { + private: + uint32_t m_records_written; + twml::ThriftWriter &m_thrift_writer; + twml::TensorRecordWriter m_tensor_writer; + + void writeBinary(twml::DataRecord &record); + void writeContinuous(twml::DataRecord &record); + void writeDiscrete(twml::DataRecord &record); + void writeString(twml::DataRecord &record); + void writeSparseBinaryFeatures(twml::DataRecord &record); + void writeSparseContinuousFeatures(twml::DataRecord &record); + void writeBlobFeatures(twml::DataRecord &record); + void writeDenseTensors(twml::DataRecord &record); + + public: + DataRecordWriter(twml::ThriftWriter &thrift_writer): + m_records_written(0), + m_thrift_writer(thrift_writer), + m_tensor_writer(twml::TensorRecordWriter(thrift_writer)) { } + + uint32_t getRecordsWritten(); + uint64_t write(twml::DataRecord &record); +}; + +} +#endif diff --git a/twml/libtwml/include/twml/Error.h b/twml/libtwml/include/twml/Error.h new file mode 100644 index 000000000..89307d214 --- /dev/null +++ b/twml/libtwml/include/twml/Error.h @@ -0,0 +1,48 @@ +#pragma once +#include + +#ifdef __cplusplus +#include +#include +#include +#include + +namespace twml { + +class Error : public std::runtime_error { + private: + twml_err m_err; + public: + Error(twml_err err, const std::string &msg) : + std::runtime_error(msg), m_err(err) + { + } + + twml_err err() const + { + return m_err; + } +}; + +class ThriftInvalidField: public twml::Error { + public: + ThriftInvalidField(int16_t field_id, const std::string& func) : + Error(TWML_ERR_THRIFT, + "Found invalid field (" + std::to_string(field_id) + + ") while reading thrift [" + func + "]") + { + } +}; + +class ThriftInvalidType: public twml::Error { + public: + ThriftInvalidType(uint8_t type_id, const std::string& func, const std::string type) : + Error(TWML_ERR_THRIFT, + "Found invalid type (" + 
+      ") while reading thrift [" + func + "::" + type + "]")
+  {
+  }
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/HashedDataRecord.h b/twml/libtwml/include/twml/HashedDataRecord.h
new file mode 100644
index 000000000..de63c4dc7
--- /dev/null
+++ b/twml/libtwml/include/twml/HashedDataRecord.h
@@ -0,0 +1,70 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <cmath>
+#include <cstdint>
+
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/TensorRecord.h>
+
+namespace twml {
+
+class HashedDataRecordReader;
+
+class TWMLAPI HashedDataRecord : public TensorRecord {
+ public:
+  typedef HashedDataRecordReader Reader;
+
+  HashedDataRecord(int num_labels=0, int num_weights=0):
+    m_keys(),
+    m_transformed_keys(),
+    m_values(),
+    m_codes(),
+    m_types(),
+    m_labels(num_labels, std::nanf("")),
+    m_weights(num_weights) {}
+
+  void decode(HashedDataRecordReader &reader);
+
+  const std::vector<int64_t> &keys() const { return m_keys; }
+  const std::vector<int64_t> &transformed_keys() const { return m_transformed_keys; }
+  const std::vector<double> &values() const { return m_values; }
+  const std::vector<int64_t> &codes() const { return m_codes; }
+  const std::vector<uint8_t> &types() const { return m_types; }
+
+  const std::vector<float> &labels() const { return m_labels; }
+  const std::vector<float> &weights() const { return m_weights; }
+
+  void clear();
+
+  uint64_t totalSize() const { return m_keys.size(); }
+
+  void extendSize(int delta_size) {
+    int count = m_keys.size() + delta_size;
+    m_keys.reserve(count);
+    m_transformed_keys.reserve(count);
+    m_values.reserve(count);
+    m_codes.reserve(count);
+    m_types.reserve(count);
+  }
+
+ private:
+  std::vector<int64_t> m_keys;
+  std::vector<int64_t> m_transformed_keys;
+  std::vector<double> m_values;
+  std::vector<int64_t> m_codes;
+  std::vector<uint8_t> m_types;
+
+  std::vector<float> m_labels;
+  std::vector<float> m_weights;
+
+  void addKey(int64_t key, int64_t transformed_key, int64_t code, uint8_t type, double value=1);
+  void addLabel(int64_t id, double value = 1);
+  void addWeight(int64_t id, double value);
+
+  friend class HashedDataRecordReader;
+};
+
+}
+#endif
\ No newline at end of file
diff --git a/twml/libtwml/include/twml/HashedDataRecordReader.h b/twml/libtwml/include/twml/HashedDataRecordReader.h
new file mode 100644
index 000000000..5470eb5c8
--- /dev/null
+++ b/twml/libtwml/include/twml/HashedDataRecordReader.h
@@ -0,0 +1,70 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/HashedDataRecord.h>
+#include <twml/TensorRecordReader.h>
+
+#include <twml/ThriftReader.h>
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace twml {
+
+enum class DecodeMode: int64_t
+{
+  hash_valname = 0,
+  hash_fname_and_valname = 1,
+};
+
+class TWMLAPI HashedDataRecordReader : public TensorRecordReader {
+private:
+  typedef Map<int64_t, int64_t> KeyMap_t;
+  KeyMap_t *m_keep_map;
+  KeyMap_t *m_labels_map;
+  KeyMap_t *m_weights_map;
+  DecodeMode m_decode_mode;
+
+public:
+  bool keepId   (const int64_t &key, int64_t &code);
+  bool isLabel  (const int64_t &key, int64_t &code);
+  bool isWeight (const int64_t &key, int64_t &code);
+  void readBinary           (const int feature_type, HashedDataRecord *record);
+  void readContinuous       (const int feature_type, HashedDataRecord *record);
+  void readDiscrete         (const int feature_type, HashedDataRecord *record);
+  void readString           (const int feature_type, HashedDataRecord *record);
+  void readSparseBinary     (const int feature_type, HashedDataRecord *record);
+  void readSparseContinuous (const int feature_type, HashedDataRecord *record);
+  void readBlob             (const int feature_type, HashedDataRecord *record);
+
+  HashedDataRecordReader() :
+    TensorRecordReader(nullptr),
+    m_keep_map(nullptr),
+    m_labels_map(nullptr),
+    m_weights_map(nullptr),
+    m_decode_mode(DecodeMode::hash_valname)
+  {}
+
+  // Using a template instead of int64_t because tensorflow implements int64 based on compiler.
+  void setKeepMap(KeyMap_t *keep_map) {
+    m_keep_map = keep_map;
+  }
+
+  void setLabelsMap(KeyMap_t *labels_map) {
+    m_labels_map = labels_map;
+  }
+
+  void setWeightsMap(KeyMap_t *weights_map) {
+    m_weights_map = weights_map;
+  }
+
+  void setDecodeMode(int64_t mode) {
+    m_decode_mode = static_cast<DecodeMode>(mode);
+  }
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/Hashmap.h b/twml/libtwml/include/twml/Hashmap.h
new file mode 100644
index 000000000..59314236b
--- /dev/null
+++ b/twml/libtwml/include/twml/Hashmap.h
@@ -0,0 +1,110 @@
+#pragma once
+#include <twml/defines.h>
+#include <twml/Tensor.h>
+#include <cstddef>
+#include <cstdint>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+  typedef void * twml_hashmap;
+  typedef int64_t tw_hash_key_t;
+  typedef int64_t tw_hash_val_t;
+#ifdef __cplusplus
+}
+#endif
+
+#ifdef __cplusplus
+namespace twml {
+
+  typedef tw_hash_key_t HashKey_t;
+  typedef tw_hash_val_t HashVal_t;
+
+  class HashMap {
+   private:
+    twml_hashmap m_hashmap;
+
+   public:
+    HashMap();
+    ~HashMap();
+
+    // Disable copy constructor and assignment
+    // TODO: Fix this after retain and release are added to twml_hashmap
+    HashMap(const HashMap &other) = delete;
+    HashMap& operator=(const HashMap &other) = delete;
+
+    void clear();
+    uint64_t size() const;
+    int8_t insert(const HashKey_t key);
+    int8_t insert(const HashKey_t key, const HashVal_t val);
+    void remove(const HashKey_t key);
+    int8_t get(HashVal_t &val, const HashKey_t key) const;
+
+    void insert(Tensor &mask, const Tensor keys);
+    void insert(Tensor &mask, const Tensor keys, const Tensor vals);
+    void remove(const Tensor keys);
+    void get(Tensor &mask, Tensor &vals, const Tensor keys) const;
+
+    void getInplace(Tensor &mask, Tensor &keys_vals) const;
+    void toTensors(Tensor &keys, Tensor &vals) const;
+  };
+}
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+  TWMLAPI twml_err twml_hashmap_create(twml_hashmap *hashmap);
+
+  TWMLAPI twml_err twml_hashmap_clear(const twml_hashmap hashmap);
+
+  TWMLAPI twml_err twml_hashmap_get_size(uint64_t *size, const twml_hashmap hashmap);
+
+  TWMLAPI twml_err twml_hashmap_delete(const twml_hashmap hashmap);
+
+  // insert, get, remove single key / value
+  TWMLAPI twml_err twml_hashmap_insert_key(int8_t *mask,
+                                           const twml_hashmap hashmap,
+                                           const tw_hash_key_t key);
+
+  TWMLAPI twml_err twml_hashmap_insert_key_and_value(int8_t *mask, twml_hashmap hashmap,
+                                                     const tw_hash_key_t key,
+                                                     const tw_hash_val_t val);
+
+  TWMLAPI twml_err twml_hashmap_remove_key(const twml_hashmap hashmap,
+                                           const tw_hash_key_t key);
+
+  TWMLAPI twml_err twml_hashmap_get_value(int8_t *mask, tw_hash_val_t *val,
+                                          const twml_hashmap hashmap,
+                                          const tw_hash_key_t key);
+
+  // insert, get, remove tensors of keys / values
+  TWMLAPI twml_err twml_hashmap_insert_keys(twml_tensor masks,
+                                            const twml_hashmap hashmap,
+                                            const twml_tensor keys);
+
+  TWMLAPI twml_err twml_hashmap_insert_keys_and_values(twml_tensor masks,
+                                                       twml_hashmap hashmap,
+                                                       const twml_tensor keys,
+                                                       const twml_tensor vals);
+
+  TWMLAPI twml_err twml_hashmap_remove_keys(const twml_hashmap hashmap,
+                                            const twml_tensor keys);
+
+  TWMLAPI twml_err twml_hashmap_get_values(twml_tensor masks,
+                                           twml_tensor vals,
+                                           const twml_hashmap hashmap,
+                                           const twml_tensor keys);
+
+  TWMLAPI twml_err twml_hashmap_get_values_inplace(twml_tensor masks,
+                                                   twml_tensor keys_vals,
+                                                   const twml_hashmap hashmap);
+
+  TWMLAPI twml_err twml_hashmap_to_tensors(twml_tensor keys,
+                                           twml_tensor vals,
+                                           const twml_hashmap hashmap);
+#ifdef __cplusplus
+}
+#endif
diff --git a/twml/libtwml/include/twml/RawTensor.h b/twml/libtwml/include/twml/RawTensor.h
new file mode 100644
index 000000000..571966743
--- /dev/null
+++ b/twml/libtwml/include/twml/RawTensor.h
@@ -0,0 +1,92 @@
+#pragma once
+#include <twml/Error.h>
+#include <twml/Tensor.h>
+
+#ifdef __cplusplus
+namespace twml {
+
+// This class contains the raw pointers to tensors coming from the thrift object.
+class TWMLAPI RawTensor : public Tensor
+{
+private:
+  bool m_is_big_endian;
+  uint64_t m_raw_length;
+public:
+
+  RawTensor() {}
+
+  RawTensor(void *data, const std::vector<uint64_t> &dims,
+      const std::vector<uint64_t> &strides, twml_type type, bool is_big_endian, uint64_t length)
+    : Tensor(data, dims, strides, type), m_is_big_endian(is_big_endian), m_raw_length(length) {}
+
+  bool is_big_endian() const {
+    return m_is_big_endian;
+  }
+
+  uint64_t getRawLength() const {
+    return m_raw_length;
+  }
+
+  // Extracts a slice from a tensor at idx0 along dimension 0
+  // Used in BatchPredictionResponse to write each slice in separate records
+  RawTensor getSlice(uint64_t idx0) const {
+    void *slice = nullptr;
+    uint64_t raw_length = 0;
+
+    if (getType() == TWML_TYPE_STRING) {
+      raw_length = getStride(0);
+      std::string *data = const_cast<std::string *>(static_cast<const std::string *>(getData<void>()));
+      slice = static_cast<void *>(data + raw_length * idx0);
+    } else {
+      raw_length = getStride(0) * getSizeOf(getType());
+      char *data = const_cast<char *>(static_cast<const char *>(getData<void>()));
+      slice = static_cast<void *>(data + raw_length * idx0);
+    }
+
+    std::vector<uint64_t> dims, strides;
+    for (int i = 1; i < getNumDims(); i++) {
+      dims.push_back(getDim(i));
+      strides.push_back(getStride(i));
+    }
+
+    return RawTensor(slice, dims, strides, getType(), m_is_big_endian, raw_length);
+  }
+};
+
+// Wrapper class around RawTensor to hold sparse tensors.
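+// The tensor is kept in COO form: an int64 indices tensor, a values tensor,
+// and the dense shape (see TensorRecordReader::readCOOSparseTensor). As an
+// illustration only (values invented here, not taken from this code): a 3x4
+// matrix with nonzeros at (0,1) and (2,3) would carry one index row per
+// nonzero, i.e. indices [[0,1],[2,3]], two values, and dense_shape [3,4].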
+class TWMLAPI RawSparseTensor
+{
+private:
+  RawTensor m_indices;
+  RawTensor m_values;
+  std::vector<uint64_t> m_dense_shape;
+
+public:
+
+  RawSparseTensor() {
+  }
+
+  RawSparseTensor(const RawTensor &indices_, const RawTensor &values_,
+                  const std::vector<uint64_t> &dense_shape_) :
+    m_indices(indices_), m_values(values_), m_dense_shape(dense_shape_)
+  {
+    if (m_indices.getType() != TWML_TYPE_INT64) {
+      throw twml::Error(TWML_ERR_TYPE, "Indices of Sparse Tensor must be of type int64");
+    }
+  }
+
+  const RawTensor &indices() const {
+    return m_indices;
+  }
+
+  const RawTensor &values() const {
+    return m_values;
+  }
+
+  const std::vector<uint64_t>& denseShape() const {
+    return m_dense_shape;
+  }
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/Tensor.h b/twml/libtwml/include/twml/Tensor.h
new file mode 100644
index 000000000..774474403
--- /dev/null
+++ b/twml/libtwml/include/twml/Tensor.h
@@ -0,0 +1,82 @@
+#pragma once
+#include <twml/defines.h>
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+  struct twml_tensor__;
+  typedef struct twml_tensor__ * twml_tensor;
+
+#ifdef __cplusplus
+}
+#endif
+
+#ifdef __cplusplus
+namespace twml {
+
+class TWMLAPI Tensor
+{
+private:
+  twml_type m_type;
+  void *m_data;
+  std::vector<uint64_t> m_dims;
+  std::vector<uint64_t> m_strides;
+
+public:
+  Tensor() {}
+  Tensor(void *data, int ndims, const uint64_t *dims, const uint64_t *strides, twml_type type);
+  Tensor(void *data, const std::vector<uint64_t> &dims, const std::vector<uint64_t> &strides, twml_type type);
+
+  const std::vector<uint64_t>& getDims() const {
+    return m_dims;
+  }
+
+  int getNumDims() const;
+  uint64_t getDim(int dim) const;
+  uint64_t getStride(int dim) const;
+  uint64_t getNumElements() const;
+  twml_type getType() const;
+
+  twml_tensor getHandle();
+  const twml_tensor getHandle() const;
+
+  template <typename T> T *getData();
+  template <typename T> const T *getData() const;
+};
+
+TWMLAPI std::string getTypeName(twml_type type);
+TWMLAPI const Tensor *getConstTensor(const twml_tensor t);
+TWMLAPI Tensor *getTensor(twml_tensor t);
+TWMLAPI uint64_t getSizeOf(twml_type type);
+
+}
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+  TWMLAPI twml_err twml_tensor_create(twml_tensor *tensor, void *data,
+                                      int ndims, uint64_t *dims,
+                                      uint64_t *strides, twml_type type);
+
+  TWMLAPI twml_err twml_tensor_delete(const twml_tensor tensor);
+
+  TWMLAPI twml_err twml_tensor_get_type(twml_type *type, const twml_tensor tensor);
+
+  TWMLAPI twml_err twml_tensor_get_data(void **data, const twml_tensor tensor);
+
+  TWMLAPI twml_err twml_tensor_get_dim(uint64_t *dim, const twml_tensor tensor, int id);
+
+  TWMLAPI twml_err twml_tensor_get_num_dims(int *ndims, const twml_tensor tensor);
+
+  TWMLAPI twml_err twml_tensor_get_num_elements(uint64_t *nelements, const twml_tensor tensor);
+
+  TWMLAPI twml_err twml_tensor_get_stride(uint64_t *stride, const twml_tensor tensor, int id);
+#ifdef __cplusplus
+}
+#endif
diff --git a/twml/libtwml/include/twml/TensorRecord.h b/twml/libtwml/include/twml/TensorRecord.h
new file mode 100644
index 000000000..d128cfdce
--- /dev/null
+++ b/twml/libtwml/include/twml/TensorRecord.h
@@ -0,0 +1,47 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <cstdint>
+#include <unordered_map>
+
+#include <twml/defines.h>
+#include <twml/RawTensor.h>
+
+namespace twml {
+
+class TensorRecordReader;
+
+// A class containing the data from TensorRecord.
+// - This serves as the base class from which DataRecord and HashedDataRecord are inherited.
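+// - Tensors are keyed by int64 feature id; an illustrative lookup (the id
+//   here is hypothetical): const twml::RawTensor &t = record.getRawTensor(feature_id);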
+class TWMLAPI TensorRecord {
+public:
+  typedef std::unordered_map<int64_t, const RawTensor> RawTensors;
+  typedef std::unordered_map<int64_t, const RawSparseTensor> RawSparseTensors;
+
+private:
+  RawTensors m_tensors;
+  RawSparseTensors m_sparse_tensors;
+
+public:
+
+  const RawTensors &getRawTensors() {
+    return m_tensors;
+  }
+
+  const RawTensor& getRawTensor(int64_t id) const {
+    return m_tensors.at(id);
+  }
+
+  const RawSparseTensor& getRawSparseTensor(int64_t id) const {
+    return m_sparse_tensors.at(id);
+  }
+
+  void addRawTensor(int64_t id, const RawTensor &tensor) {
+    m_tensors.emplace(id, tensor);
+  }
+
+  friend class TensorRecordReader;
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/TensorRecordReader.h b/twml/libtwml/include/twml/TensorRecordReader.h
new file mode 100644
index 000000000..3a62bd885
--- /dev/null
+++ b/twml/libtwml/include/twml/TensorRecordReader.h
@@ -0,0 +1,34 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <twml/defines.h>
+#include <twml/RawTensor.h>
+#include <twml/TensorRecord.h>
+
+#include <twml/ThriftReader.h>
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace twml {
+
+// Class that parses the thrift objects as defined in tensor.thrift
+class TWMLAPI TensorRecordReader : public ThriftReader {
+
+  std::vector<uint64_t> readShape();
+  template <typename T> RawTensor readTypedTensor();
+  RawTensor readRawTypedTensor();
+  RawTensor readStringTensor();
+  RawTensor readGeneralTensor();
+  RawSparseTensor readCOOSparseTensor();
+
+public:
+  void readTensor(const int feature_type, TensorRecord *record);
+  void readSparseTensor(const int feature_type, TensorRecord *record);
+
+  TensorRecordReader(const uint8_t *buffer) : ThriftReader(buffer) {}
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/TensorRecordWriter.h b/twml/libtwml/include/twml/TensorRecordWriter.h
new file mode 100644
index 000000000..d8b7c3dbf
--- /dev/null
+++ b/twml/libtwml/include/twml/TensorRecordWriter.h
@@ -0,0 +1,35 @@
+#pragma once
+#ifdef __cplusplus
+
+#include <twml/TensorRecord.h>
+#include <twml/ThriftWriter.h>
+
+namespace twml {
+
+// Encodes tensors as DataRecord/TensorRecord-compatible Thrift.
+// DataRecordWriter relies on this class to encode the tensor fields.
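+//
+// Illustrative usage sketch (buffer handling elided; names are hypothetical):
+//   twml::ThriftWriter thrift_writer(buffer, buffer_size);
+//   twml::TensorRecordWriter writer(thrift_writer);
+//   writer.write(record);  // appends the record's tensors to the buffer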
+class TWMLAPI TensorRecordWriter {
+
+private:
+  uint32_t m_records_written;
+  twml::ThriftWriter &m_thrift_writer;
+
+  void writeTensor(const RawTensor &tensor);
+  void writeRawTensor(const RawTensor &tensor);
+
+public:
+  TensorRecordWriter(twml::ThriftWriter &thrift_writer):
+    m_records_written(0),
+    m_thrift_writer(thrift_writer) { }
+
+  uint32_t getRecordsWritten();
+
+  // Caller (usually DataRecordWriter) must precede this with a struct field header,
+  // like thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_GENERAL_TENSOR)
+  //
+  // All tensors are written as RawTensors except for StringTensors
+  uint64_t write(twml::TensorRecord &record);
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/ThriftReader.h b/twml/libtwml/include/twml/ThriftReader.h
new file mode 100644
index 000000000..25c83ea29
--- /dev/null
+++ b/twml/libtwml/include/twml/ThriftReader.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#ifdef __cplusplus
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include <twml/defines.h>
+
+namespace twml {
+
+class ThriftReader {
+ protected:
+  const uint8_t *m_buffer;
+
+ public:
+
+  ThriftReader(const uint8_t *buffer): m_buffer(buffer) {}
+
+  const uint8_t *getBuffer() { return m_buffer; }
+
+  void setBuffer(const uint8_t *buffer) { m_buffer = buffer; }
+
+  template <typename T> T readDirect() {
+    T val;
+    memcpy(&val, m_buffer, sizeof(T));
+    m_buffer += sizeof(T);
+    return val;
+  }
+
+  template <typename T> void skip() {
+    m_buffer += sizeof(T);
+  }
+
+  void skipLength(size_t length) {
+    m_buffer += length;
+  }
+
+  uint8_t readByte();
+  int16_t readInt16();
+  int32_t readInt32();
+  int64_t readInt64();
+  double readDouble();
+
+  template <typename T> inline
+  int32_t getRawBuffer(const uint8_t **begin) {
+    int32_t length = readInt32();
+    *begin = m_buffer;
+    skipLength(length * sizeof(T));
+    return length;
+  }
+
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/ThriftWriter.h b/twml/libtwml/include/twml/ThriftWriter.h
new file mode 100644
index 000000000..1216415b0
--- /dev/null
+++ b/twml/libtwml/include/twml/ThriftWriter.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#ifdef __cplusplus
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+#include <twml/defines.h>
+
+namespace twml {
+
+// A low-level binary Thrift writer that can also compute output size
+// in dry run mode without copying memory. See also https://git.io/vNPiv
+//
+// WARNING: Users of this class are responsible for generating valid Thrift
+// by following the Thrift binary protocol (https://git.io/vNPiv).
+class TWMLAPI ThriftWriter {
+ protected:
+  bool m_dry_run;
+  uint8_t *m_buffer;
+  size_t m_buffer_size;
+  size_t m_bytes_written;
+
+  template <typename T> inline uint64_t write(T val);
+
+ public:
+  // buffer: Memory to write the binary Thrift to.
+  // buffer_size: Length of the buffer.
+  // dry_run: If true, just count bytes 'written' but do not copy memory.
+  //          If false, write binary Thrift to the buffer normally.
+  //          Useful to determine output size for TensorFlow allocations.
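+  //
+  // A sketch of the two-pass pattern this enables (encode() is a hypothetical
+  // helper here; BatchPredictionResponse::encodedSize does the real thing):
+  //   twml::ThriftWriter dry(nullptr, 0, true);  // dry run: count bytes only
+  //   encode(dry);
+  //   std::vector<uint8_t> buf(dry.getBytesWritten());
+  //   twml::ThriftWriter writer(buf.data(), buf.size());
+  //   encode(writer);                            // real pass: copy memory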
+  ThriftWriter(uint8_t *buffer, size_t buffer_size, bool dry_run = false) :
+    m_dry_run(dry_run),
+    m_buffer(buffer),
+    m_buffer_size(buffer_size),
+    m_bytes_written(0) {}
+
+  // total bytes written to the buffer since object creation
+  uint64_t getBytesWritten();
+
+  // encode headers and values into the buffer
+  uint64_t writeStructFieldHeader(int8_t field_type, int16_t field_id);
+  uint64_t writeStructStop();
+  uint64_t writeListHeader(int8_t element_type, int32_t num_elems);
+  uint64_t writeMapHeader(int8_t key_type, int8_t val_type, int32_t num_elems);
+  uint64_t writeDouble(double val);
+  uint64_t writeInt8(int8_t val);
+  uint64_t writeInt16(int16_t val);
+  uint64_t writeInt32(int32_t val);
+  uint64_t writeInt64(int64_t val);
+  uint64_t writeBinary(const uint8_t *bytes, int32_t num_bytes);
+  // clients expect UTF-8-encoded strings per the Thrift protocol
+  // (often this is just used to send bytes, not real strings though)
+  uint64_t writeString(std::string str);
+  uint64_t writeBool(bool val);
+};
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/Type.h b/twml/libtwml/include/twml/Type.h
new file mode 100644
index 000000000..8b460c812
--- /dev/null
+++ b/twml/libtwml/include/twml/Type.h
@@ -0,0 +1,69 @@
+#pragma once
+#include <twml/defines.h>
+#include <cstdint>
+#include <string>
+
+#ifdef __cplusplus
+namespace twml {
+
+  template <typename T> struct Type;
+
+  template<> struct Type<float>
+  {
+    enum {
+      type = TWML_TYPE_FLOAT,
+    };
+  };
+
+  template<> struct Type<std::string>
+  {
+    enum {
+      type = TWML_TYPE_STRING,
+    };
+  };
+
+  template<> struct Type<double>
+  {
+    enum {
+      type = TWML_TYPE_DOUBLE,
+    };
+  };
+
+  template<> struct Type<int64_t>
+  {
+    enum {
+      type = TWML_TYPE_INT64,
+    };
+  };
+
+  template<> struct Type<int32_t>
+  {
+    enum {
+      type = TWML_TYPE_INT32,
+    };
+  };
+
+  template<> struct Type<int8_t>
+  {
+    enum {
+      type = TWML_TYPE_INT8,
+    };
+  };
+
+  template<> struct Type<uint8_t>
+  {
+    enum {
+      type = TWML_TYPE_UINT8,
+    };
+  };
+
+  template<> struct Type<bool>
+  {
+    enum {
+      type = TWML_TYPE_BOOL,
+    };
+  };
+
+}
+#endif
diff --git a/twml/libtwml/include/twml/common.h b/twml/libtwml/include/twml/common.h
new file mode 100644
index 000000000..c3a2e9aee
--- /dev/null
+++ b/twml/libtwml/include/twml/common.h
@@ -0,0 +1,42 @@
+#ifndef TWML_LIBTWML_INCLUDE_TWML_COMMON_H_
+#define TWML_LIBTWML_INCLUDE_TWML_COMMON_H_
+
+#define USE_ABSEIL_HASH 1
+
+#if defined(USE_ABSEIL_HASH)
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#elif defined(USE_DENSE_HASH)
+#include <sparsehash/dense_hash_map>
+#include <sparsehash/dense_hash_set>
+#else
+#include <unordered_map>
+#include <unordered_set>
+#endif  // USE_ABSEIL_HASH
+
+
+namespace twml {
+#if defined(USE_ABSEIL_HASH)
+  template <typename KeyType, typename ValueType>
+  using Map = absl::flat_hash_map<KeyType, ValueType>;
+
+  template <typename KeyType>
+  using Set = absl::flat_hash_set<KeyType>;
+#elif defined(USE_DENSE_HASH)
+// Do not use this unless a proper empty key can be found.
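+// dense_hash_map/set reserve one key value as an "empty" marker and require
+// set_empty_key() before use; the USE_DENSE_HASH branches in DataRecord's
+// constructor pass 0, so 0 must never occur as a real key.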
+  template <typename KeyType, typename ValueType>
+  using Map = google::dense_hash_map<KeyType, ValueType>;
+
+  template <typename KeyType>
+  using Set = google::dense_hash_set<KeyType>;
+#else
+  template <typename KeyType, typename ValueType>
+  using Map = std::unordered_map<KeyType, ValueType>;
+
+  template <typename KeyType>
+  using Set = std::unordered_set<KeyType>;
+#endif  // USE_ABSEIL_HASH
+
+}  // namespace twml
+
+#endif  // TWML_LIBTWML_INCLUDE_TWML_COMMON_H_
\ No newline at end of file
diff --git a/twml/libtwml/include/twml/defines.h b/twml/libtwml/include/twml/defines.h
new file mode 100644
index 000000000..e7f7d138d
--- /dev/null
+++ b/twml/libtwml/include/twml/defines.h
@@ -0,0 +1,36 @@
+#pragma once
+#include <stdint.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+  typedef enum {
+    TWML_TYPE_FLOAT32 = 1,
+    TWML_TYPE_FLOAT64 = 2,
+    TWML_TYPE_INT32 = 3,
+    TWML_TYPE_INT64 = 4,
+    TWML_TYPE_INT8 = 5,
+    TWML_TYPE_UINT8 = 6,
+    TWML_TYPE_BOOL = 7,
+    TWML_TYPE_STRING = 8,
+    TWML_TYPE_FLOAT = TWML_TYPE_FLOAT32,
+    TWML_TYPE_DOUBLE = TWML_TYPE_FLOAT64,
+    TWML_TYPE_UNKNOWN = -1,
+  } twml_type;
+
+  typedef enum {
+    TWML_ERR_NONE = 1000,
+    TWML_ERR_SIZE = 1001,
+    TWML_ERR_TYPE = 1002,
+    TWML_ERR_THRIFT = 1100,
+    TWML_ERR_IO = 1200,
+    TWML_ERR_UNKNOWN = 1999,
+  } twml_err;
+#ifdef __cplusplus
+}
+#endif
+
+#define TWMLAPI __attribute__((visibility("default")))
+
+#ifndef TWML_INDEX_BASE
+#define TWML_INDEX_BASE 0
+#endif
diff --git a/twml/libtwml/include/twml/discretizer_impl.h b/twml/libtwml/include/twml/discretizer_impl.h
new file mode 100644
index 000000000..587bde458
--- /dev/null
+++ b/twml/libtwml/include/twml/discretizer_impl.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/Tensor.h>
+
+#ifdef __cplusplus
+namespace twml {
+  TWMLAPI void discretizerInfer(
+    Tensor &output_keys,
+    Tensor &output_vals,
+    const Tensor &input_ids,
+    const Tensor &input_vals,
+    const Tensor &bin_ids,
+    const Tensor &bin_vals,
+    const Tensor &feature_offsets,
+    int output_bits,
+    const Map<int64_t, int64_t> &ID_to_index,
+    int start_compute,
+    int end_compute,
+    int output_start);
+}  // namespace twml
+#endif
diff --git a/twml/libtwml/include/twml/functions.h b/twml/libtwml/include/twml/functions.h
new file mode 100644
index 000000000..c23680cac
--- /dev/null
+++ b/twml/libtwml/include/twml/functions.h
@@ -0,0 +1,26 @@
+#pragma once
+#include <twml/defines.h>
+#include <twml/Tensor.h>
+
+#ifdef __cplusplus
+namespace twml {
+
+  // Adding these as an easy way to test the wrappers
+  TWMLAPI void add1(Tensor &output, const Tensor input);
+  TWMLAPI void copy(Tensor &output, const Tensor input);
+  TWMLAPI int64_t featureId(const std::string &feature);
+}
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+  // Adding these as an easy way to test the wrappers
+  TWMLAPI twml_err twml_add1(twml_tensor output, const twml_tensor input);
+  TWMLAPI twml_err twml_copy(twml_tensor output, const twml_tensor input);
+  TWMLAPI twml_err twml_get_feature_id(int64_t *result, const uint64_t len, const char *str);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/twml/libtwml/include/twml/hashing_discretizer_impl.h b/twml/libtwml/include/twml/hashing_discretizer_impl.h
new file mode 100644
index 000000000..a04efb7e0
--- /dev/null
+++ b/twml/libtwml/include/twml/hashing_discretizer_impl.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <twml/common.h>
+#include <twml/defines.h>
+#include <twml/Tensor.h>
+#include <cstdint>
+
+#ifdef __cplusplus
+namespace twml {
+  TWMLAPI void hashDiscretizerInfer(
+    Tensor &output_keys,
+    Tensor &output_vals,
+    const Tensor &input_ids,
+    const Tensor &input_vals,
+    int n_bin,
+    const Tensor &bin_vals,
+    int output_bits,
+    const Map<int64_t, int64_t> &ID_to_index,
+    int start_compute,
+    int end_compute,
+    int64_t options);
+}  // namespace twml
+#endif
diff --git a/twml/libtwml/include/twml/io/IOError.h b/twml/libtwml/include/twml/io/IOError.h
new file mode 100644
index 000000000..867ab44df
--- /dev/null
+++ b/twml/libtwml/include/twml/io/IOError.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <twml/Error.h>
+
+namespace twml {
+namespace io {
+
+class IOError : public twml::Error {
+ public:
+  enum Status {
+    OUT_OF_RANGE = 1,
+    WRONG_MAGIC = 2,
+    WRONG_HEADER = 3,
+    ERROR_HEADER_CHECKSUM = 4,
+    INVALID_METHOD = 5,
+    USING_RESERVED = 6,
+    ERROR_HEADER_EXTRA_FIELD_CHECKSUM = 7,
+    CANT_FIT_OUTPUT = 8,
+    SPLIT_FILE = 9,
+    BLOCK_SIZE_TOO_LARGE = 10,
+    SOURCE_LARGER_THAN_DESTINATION = 11,
+    DESTINATION_LARGER_THAN_CAPACITY = 12,
+    HEADER_FLAG_MISMATCH = 13,
+    NOT_ENOUGH_INPUT = 14,
+    ERROR_SOURCE_BLOCK_CHECKSUM = 15,
+    COMPRESSED_DATA_VIOLATION = 16,
+    ERROR_DESTINATION_BLOCK_CHECKSUM = 17,
+    EMPTY_RECORD = 18,
+    MALFORMED_MEMORY_RECORD = 19,
+    UNSUPPORTED_OUTPUT_TYPE = 20,
+    OTHER_ERROR
+  };
+
+  IOError(Status status);
+
+  Status status() const {
+    return m_status;
+  }
+
+ private:
+  Status m_status;
+};
+
+}
+}
diff --git a/twml/libtwml/include/twml/optim.h b/twml/libtwml/include/twml/optim.h
new file mode 100644
index 000000000..d0a2df4ef
--- /dev/null
+++ b/twml/libtwml/include/twml/optim.h
@@ -0,0 +1,51 @@
+#pragma once
+#include <twml/defines.h>
+#include <twml/Tensor.h>
+
+#ifdef __cplusplus
+namespace twml {
+  TWMLAPI void linearInterpolation(
+    Tensor output,
+    const Tensor input,
+    const Tensor xs,
+    const Tensor ys);
+
+  TWMLAPI void nearestInterpolation(
+    Tensor output,
+    const Tensor input,
+    const Tensor xs,
+    const Tensor ys);
+
+  TWMLAPI void mdlInfer(
+    Tensor &output_keys,
+    Tensor &output_vals,
+    const Tensor &input_keys,
+    const Tensor &input_vals,
+    const Tensor &bin_ids,
+    const Tensor &bin_vals,
+    const Tensor &feature_offsets,
+    bool return_bin_indices = false);
+}
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+  TWMLAPI twml_err twml_optim_nearest_interpolation(
+    twml_tensor output,
+    const twml_tensor input,
+    const twml_tensor xs,
+    const twml_tensor ys);
+
+  TWMLAPI twml_err twml_optim_mdl_infer(
+    twml_tensor output_keys,
+    twml_tensor output_vals,
+    const twml_tensor input_keys,
+    const twml_tensor input_vals,
+    const twml_tensor bin_ids,
+    const twml_tensor bin_vals,
+    const twml_tensor feature_offsets,
+    const bool return_bin_indices = false);
+#ifdef __cplusplus
+}
+#endif
diff --git a/twml/libtwml/include/twml/utilities.h b/twml/libtwml/include/twml/utilities.h
new file mode 100644
index 000000000..a30b44aff
--- /dev/null
+++ b/twml/libtwml/include/twml/utilities.h
@@ -0,0 +1,18 @@
+#pragma once
+#ifdef __cplusplus
+#include <cstdint>
+
+namespace twml {
+
+inline int64_t mixDiscreteIdAndValue(int64_t key, int64_t value) {
+  key ^= ((17LL + value) * 2654435761LL);
+  return key;
+}
+
+inline int64_t mixStringIdAndValue(int64_t key, int32_t str_len, const uint8_t *str) {
+  int32_t hash = 0;
+  for (int32_t i = 0; i < str_len; i++) {
+    hash = (31 * hash) + (int32_t)str[i];
+  }
+  return key ^ hash;
+}
+}
+#endif
\ No newline at end of file
diff --git a/twml/libtwml/setup.cfg b/twml/libtwml/setup.cfg
new file mode 100644
index 000000000..d5253c179
--- /dev/null
+++ b/twml/libtwml/setup.cfg
@@ -0,0 +1,9 @@
+[bdist_wheel]
+universal=1
+
+[build]
+build-lib=build_dir
+build-temp=build_dir
+
+[bdist]
+bdist-base=build_dir
diff --git a/twml/libtwml/setup.py b/twml/libtwml/setup.py
new file mode 100644
index 000000000..2dcfa105d
--- /dev/null
+++ b/twml/libtwml/setup.py
@@ -0,0 +1,12 @@
+"""
+libtwml setup.py module
+"""
+from setuptools import setup, find_packages
+
+setup(
+    name='libtwml',
+    version='2.0',
+    description="Tensorflow C++ ops for twml",
C++ ops for twml", + packages=find_packages(), + data_files=[('', ['libtwml_tf.so'])], +) diff --git a/twml/libtwml/src/lib/BatchPredictionRequest.cpp b/twml/libtwml/src/lib/BatchPredictionRequest.cpp new file mode 100644 index 000000000..cca8d6545 --- /dev/null +++ b/twml/libtwml/src/lib/BatchPredictionRequest.cpp @@ -0,0 +1,52 @@ +#include "internal/thrift.h" +#include "internal/error.h" + +#include +#include +#include +#include + +#include +#include +#include + +namespace twml { + +template +void GenericBatchPredictionRequest::decode(Reader &reader) { + uint8_t feature_type = reader.readByte(); + while (feature_type != TTYPE_STOP) { + int16_t field_id = reader.readInt16(); + + switch (field_id) { + case 1: { + CHECK_THRIFT_TYPE(feature_type, TTYPE_LIST, "list"); + CHECK_THRIFT_TYPE(reader.readByte(), TTYPE_STRUCT, "list_element"); + + int32_t length = reader.readInt32(); + m_requests.resize(length, RecordType(this->num_labels, this->num_weights)); + for (auto &request : m_requests) { + request.decode(reader); + } + + break; + } + case 2: { + CHECK_THRIFT_TYPE(feature_type, TTYPE_STRUCT, "commonFeatures"); + m_common_features.decode(reader); + break; + } + default: throw ThriftInvalidField(field_id, __func__); + } + + feature_type = reader.readByte(); + } + return; +} + + +// Instantiate decoders. +template void GenericBatchPredictionRequest::decode(HashedDataRecordReader &reader); +template void GenericBatchPredictionRequest::decode(DataRecordReader &reader); + +} // namespace twml diff --git a/twml/libtwml/src/lib/BatchPredictionResponse.cpp b/twml/libtwml/src/lib/BatchPredictionResponse.cpp new file mode 100644 index 000000000..2a17d3605 --- /dev/null +++ b/twml/libtwml/src/lib/BatchPredictionResponse.cpp @@ -0,0 +1,125 @@ +#include "internal/endianutils.h" +#include "internal/error.h" +#include "internal/thrift.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +// When the number of predictions is very high, as some cases that Ads wants, the generic thrift +// encoder becomes super expensive because we have to deal with lua tables. +// This function is a special operation to efficiently write a batch prediction responses based on +// tensors. 
+namespace twml {
+
+BatchPredictionResponse::BatchPredictionResponse(
+    const Tensor &keys, const Tensor &values,
+    const Tensor &dense_keys, const std::vector<RawTensor> &dense_values
+) : keys_(keys), values_(values), dense_keys_(dense_keys), dense_values_(dense_values) {
+  // determine batch size
+  if (values_.getNumDims() > 0) {
+    batch_size_ = values_.getDim(0);
+  } else if (dense_keys_.getNumElements() < 1) {
+    throw twml::Error(TWML_ERR_TYPE, "Continuous values and dense tensors are both empty");
+  } else if (dense_keys_.getNumElements() != dense_values_.size()) {
+    throw twml::Error(TWML_ERR_TYPE, "Number of tensors not equal to number of keys");
+  } else {
+    // dim 0 for each tensor indexes batch elements
+    std::vector<uint64_t> batch_sizes;
+    batch_sizes.reserve(dense_values_.size());
+
+    for (size_t i = 0; i < dense_values_.size(); i++)
+      batch_sizes.push_back(dense_values_.at(i).getDim(0));
+
+    if (std::adjacent_find(
+          batch_sizes.begin(),
+          batch_sizes.end(),
+          std::not_equal_to<uint64_t>()) != batch_sizes.end())
+      throw twml::Error(TWML_ERR_TYPE, "Batch size (dim 0) for all tensors must be the same");
+
+    batch_size_ = dense_values.at(0).getDim(0);
+  }
+}
+
+void BatchPredictionResponse::encode(twml::ThriftWriter &thrift_writer) {
+  if (hasContinuous()) {
+    switch (values_.getType()) {
+      case TWML_TYPE_FLOAT:
+        serializePredictions<float>(thrift_writer);
+        break;
+      case TWML_TYPE_DOUBLE:
+        serializePredictions<double>(thrift_writer);
+        break;
+      default:
+        throw twml::Error(TWML_ERR_TYPE, "Predictions must be float or double.");
+    }
+  } else {
+    // dense tensor predictions (the type parameter is unused in this case)
+    serializePredictions<float>(thrift_writer);
+  }
+}
+
+template <typename T>
+void BatchPredictionResponse::serializePredictions(twml::ThriftWriter &thrift_writer) {
+  twml::DataRecordWriter record_writer = twml::DataRecordWriter(thrift_writer);
+
+  // start BatchPredictionResponse
+  thrift_writer.writeStructFieldHeader(TTYPE_LIST, BPR_PREDICTIONS);
+  thrift_writer.writeListHeader(TTYPE_STRUCT, getBatchSize());
+
+  for (int i = 0; i < getBatchSize(); i++) {
+    twml::DataRecord record = twml::DataRecord();
+
+    if (hasContinuous()) {
+      const T *values = values_.getData<T>();
+      const int64_t *local_keys = keys_.getData<int64_t>();
+      const T *local_values = values + (i * getPredictionSize());
+      record.addContinuous(local_keys, getPredictionSize(), local_values);
+    }
+
+    if (hasDenseTensors()) {
+      const int64_t *local_dense_keys = dense_keys_.getData<int64_t>();
+
+      for (uint64_t j = 0; j < dense_keys_.getNumElements(); j++) {
+        const RawTensor &dense_value = dense_values_.at(j).getSlice(i);
+        record.addRawTensor(local_dense_keys[j], dense_value);
+      }
+    }
+
+    record_writer.write(record);
+  }
+
+  // end BatchPredictionResponse
+  thrift_writer.writeStructStop();
+}
+
+// calculate expected binary Thrift size (no memory is copied)
+uint64_t BatchPredictionResponse::encodedSize() {
+  bool dry_mode = true;
+  twml::ThriftWriter dry_writer = twml::ThriftWriter(nullptr, 0, dry_mode);
+  encode(dry_writer);
+  return dry_writer.getBytesWritten();
+}
+
+void BatchPredictionResponse::write(Tensor &result) {
+  size_t result_size = result.getNumElements();
+  uint8_t *result_data = result.getData<uint8_t>();
+
+  if (result_size != this->encodedSize()) {
+    throw twml::Error(TWML_ERR_SIZE, "Sizes do not match");
+  }
+
+  twml::ThriftWriter writer = twml::ThriftWriter(result_data, result_size);
+  encode(writer);
+}
+
+}  // namespace twml
diff --git a/twml/libtwml/src/lib/BlockFormatReader.cpp b/twml/libtwml/src/lib/BlockFormatReader.cpp
new file mode 100644
index 000000000..98f49ac4f
--- /dev/null
+++ b/twml/libtwml/src/lib/BlockFormatReader.cpp
@@ -0,0 +1,145 @@
+#include <twml/BlockFormatReader.h>
+#include <cstring>
+#include <stdexcept>
+
+#define OFFSET_CHUNK (32768)
+#define RECORDS_PER_BLOCK (100)
+
+#define WIRE_TYPE_VARINT (0)
+#define WIRE_TYPE_64BIT (1)
+#define WIRE_TYPE_LENGTH_PREFIXED (2)
+
+/*
+  This was all extracted from the ancient elephant bird scrolls
+  https://github.com/twitter/elephant-bird/blob/master/core/src/main/java/com/twitter/elephantbird/mapreduce/io/BinaryBlockReader.java
+*/
+
+#define MARKER_SIZE (16)
+static uint8_t _marker[MARKER_SIZE] = {
+  0x29, 0xd8, 0xd5, 0x06, 0x58, 0xcd, 0x4c, 0x29,
+  0xb2, 0xbc, 0x57, 0x99, 0x21, 0x71, 0xbd, 0xff
+};
+
+
+namespace twml {
+BlockFormatReader::BlockFormatReader():
+  record_size_(0), block_pos_(0), block_end_(0) {
+  memset(classname_, 0, sizeof(classname_));
+}
+
+
+bool BlockFormatReader::next() {
+  record_size_ = read_one_record_size();
+  if (record_size_ < 0) {
+    record_size_ = 0;
+    return false;
+  }
+  return true;
+}
+
+int BlockFormatReader::read_int() {
+  uint8_t buff[4];
+  if (read_bytes(buff, 1, 4) != 4)
+    return -1;
+  return static_cast<int>(buff[0])
+    | (static_cast<int>(buff[1]) << 8)
+    | (static_cast<int>(buff[2]) << 16)
+    | (static_cast<int>(buff[3]) << 24);
+}
+
+int BlockFormatReader::consume_marker(int scan) {
+  uint8_t buff[MARKER_SIZE];
+  if (read_bytes(buff, 1, MARKER_SIZE) != MARKER_SIZE)
+    return 0;
+
+  while (memcmp(buff, _marker, MARKER_SIZE) != 0) {
+    if (!scan) return 0;
+    memmove(buff, buff + 1, MARKER_SIZE - 1);
+    if (read_bytes(buff + MARKER_SIZE - 1, 1, 1) != 1)
+      return 0;
+  }
+  return 1;
+}
+
+int BlockFormatReader::unpack_varint_i32() {
+  int value = 0;
+  for (int i = 0; i < 10; i++) {
+    uint8_t x;
+    if (read_bytes(&x, 1, 1) != 1)
+      return -1;
+    block_pos_++;
+    value |= static_cast<int>(x & 0x7F) << (i * 7);
+    if ((x & 0x80) == 0) break;
+  }
+  return value;
+}
+
+
+int BlockFormatReader::unpack_tag_and_wiretype(uint32_t *tag, uint32_t *wiretype) {
+  uint8_t x;
+  if (read_bytes(&x, 1, 1) != 1)
+    return -1;
+
+  block_pos_++;
+  *tag = (x & 0x7f) >> 3;
+  *wiretype = x & 7;
+  if ((x & 0x80) == 0)
+    return 0;
+
+  return -1;
+}
+
+int BlockFormatReader::unpack_string(char *out, uint64_t max_out_len) {
+  int len = unpack_varint_i32();
+  if (len < 0) return -1;
+  uint64_t slen = len;
+  if (slen + 1 > max_out_len) return -1;
+  uint64_t n = read_bytes(out, 1, slen);
+  if (n != slen) return -1;
+  block_pos_ += n;
+  out[n] = 0;
+  return 0;
+}
+
+int BlockFormatReader::read_one_record_size() {
+  for (int i = 0; i < 2; i++) {
+    if (block_end_ == 0) {
+      while (consume_marker(1)) {
+        int block_size = read_int();
+        if (block_size > 0) {
+          block_pos_ = 0;
+          block_end_ = block_size;
+          uint32_t tag, wiretype;
+          if (unpack_tag_and_wiretype(&tag, &wiretype))
+            throw std::invalid_argument("unsupported tag and wiretype");
+          if (tag != 1 || wiretype != WIRE_TYPE_VARINT)
+            throw std::invalid_argument("unexpected tag and wiretype");
+          int version = unpack_varint_i32();
+          if (version != 1)
+            throw std::invalid_argument("unsupported version");
+          if (unpack_tag_and_wiretype(&tag, &wiretype))
+            throw std::invalid_argument("unsupported tag and wiretype");
+          if (tag != 2 || wiretype != WIRE_TYPE_LENGTH_PREFIXED)
+            throw std::invalid_argument("unexpected tag and wiretype");
+          if (unpack_string(classname_, sizeof(classname_)-1))
+            throw std::invalid_argument("unsupported class name");
+          break;
+        }
+      }
+    }
+    if (block_pos_ < block_end_) {
+      uint32_t tag, wiretype;
+      if (unpack_tag_and_wiretype(&tag, &wiretype))
+        throw std::invalid_argument("unsupported tag and wiretype");
+      if (tag != 3
&& wiretype != WIRE_TYPE_LENGTH_PREFIXED) + throw std::invalid_argument("unexpected tag and wiretype"); + int record_size = unpack_varint_i32(); + block_pos_ += record_size; + return record_size; + } else { + block_end_ = 0; + } + } + return -1; +} +} // namespace twml diff --git a/twml/libtwml/src/lib/BlockFormatWriter.cpp b/twml/libtwml/src/lib/BlockFormatWriter.cpp new file mode 100644 index 000000000..d66e17351 --- /dev/null +++ b/twml/libtwml/src/lib/BlockFormatWriter.cpp @@ -0,0 +1,163 @@ +#include "internal/error.h" +#include +#include +#include + +#define WIRE_TYPE_LENGTH_PREFIXED (2) +#define WIRE_TYPE_VARINT (0) + +#ifndef PATH_MAX +#define PATH_MAX (8096) +#endif + +#define MARKER_SIZE (16) +static uint8_t _marker[MARKER_SIZE] = { + 0x29, 0xd8, 0xd5, 0x06, 0x58, 0xcd, 0x4c, 0x29, + 0xb2, 0xbc, 0x57, 0x99, 0x21, 0x71, 0xbd, 0xff +}; +namespace twml { + + BlockFormatWriter::BlockFormatWriter(const char *file_name, int record_per_block) : + file_name_(file_name), record_index_(0), records_per_block_(record_per_block) { + snprintf(temp_file_name_, PATH_MAX, "%s.block", file_name); + outputfile_ = fopen(file_name_, "a"); + } + + BlockFormatWriter::~BlockFormatWriter() { + fclose(outputfile_); + } + // TODO: use fstream + int BlockFormatWriter::pack_tag_and_wiretype(FILE *buffer, uint32_t tag, uint32_t wiretype) { + uint8_t x = ((tag & 0x0f) << 3) | (wiretype & 0x7); + size_t n = fwrite(&x, 1, 1, buffer); + if (n != 1) { + return -1; + } + return 0; + } + + int BlockFormatWriter::pack_varint_i32(FILE *buffer, int value) { + for (int i = 0; i < 10; i++) { + uint8_t x = value & 0x7F; + value = value >> 7; + if (value != 0) x |= 0x80; + size_t n = fwrite(&x, 1, 1, buffer); + if (n != 1) { + return -1; + } + if (value == 0) break; + } + return 0; + } + + int BlockFormatWriter::pack_string(FILE *buffer, const char *in, size_t in_len) { + if (pack_varint_i32(buffer, in_len)) return -1; + size_t n = fwrite(in, 1, in_len, buffer); + if (n != in_len) return -1; + return 0; + } + + int BlockFormatWriter::write_int(FILE *buffer, int value) { + uint8_t buff[4]; + buff[0] = value & 0xff; + buff[1] = (value >> 8) & 0xff; + buff[2] = (value >> 16) & 0xff; + buff[3] = (value >> 24) & 0xff; + size_t n = fwrite(buff, 1, 4, buffer); + if (n != 4) { + return -1; + } + return 0; + } + + int BlockFormatWriter::write(const char *class_name, const char *record, int record_len) { + if (record) { + record_index_++; + // The buffer holds max records_per_block_ of records (block). 
+      FILE *buffer = fopen(temp_file_name_, "a");
+      if (!buffer) return -1;
+      if (ftell(buffer) == 0) {
+        if (pack_tag_and_wiretype(buffer, 1, WIRE_TYPE_VARINT))
+          throw std::invalid_argument("Error writing tag and wiretype");
+        if (pack_varint_i32(buffer, 1))
+          throw std::invalid_argument("Error writing varint_i32");
+        if (pack_tag_and_wiretype(buffer, 2, WIRE_TYPE_LENGTH_PREFIXED))
+          throw std::invalid_argument("Error writing tag and wiretype");
+        if (pack_string(buffer, class_name, strlen(class_name)))
+          throw std::invalid_argument("Error writing class name");
+      }
+      if (pack_tag_and_wiretype(buffer, 3, WIRE_TYPE_LENGTH_PREFIXED))
+        throw std::invalid_argument("Error writing tag and wiretype");
+      if (pack_string(buffer, record, record_len))
+        throw std::invalid_argument("Error writing record");
+      fclose(buffer);
+    }
+
+    if ((record_index_ % records_per_block_) == 0) {
+      flush();
+    }
+    return 0;
+  }
+
+  int BlockFormatWriter::flush() {
+    // Flush the records in the buffer to outputfile
+    FILE *buffer = fopen(temp_file_name_, "r");
+    if (buffer) {
+      fseek(buffer, 0, SEEK_END);
+      int64_t block_size = ftell(buffer);
+      fseek(buffer, 0, SEEK_SET);
+
+      if (fwrite(_marker, sizeof(_marker), 1, outputfile_) != 1) return 1;
+      if (write_int(outputfile_, block_size)) return 1;
+      uint8_t buff[4096];
+      while (1) {
+        size_t n = fread(buff, 1, sizeof(buff), buffer);
+        if (n) {
+          size_t x = fwrite(buff, 1, n, outputfile_);
+          if (x != n) return 1;
+        }
+        if (n != sizeof(buff)) break;
+      }
+      fclose(buffer);
+      // Remove the buffer
+      if (remove(temp_file_name_)) return 1;
+    }
+    return 0;
+  }
+
+  block_format_writer BlockFormatWriter::getHandle() {
+    return reinterpret_cast<block_format_writer>(this);
+  }
+
+  BlockFormatWriter *getBlockFormatWriter(block_format_writer w) {
+    return reinterpret_cast<BlockFormatWriter *>(w);
+  }
+
+}  // namespace twml
+
+twml_err block_format_writer_create(block_format_writer *w, const char *file_name, int records_per_block) {
+  HANDLE_EXCEPTIONS(
+    twml::BlockFormatWriter *writer = new twml::BlockFormatWriter(file_name, records_per_block);
+    *w = reinterpret_cast<block_format_writer>(writer););
+  return TWML_ERR_NONE;
+}
+
+twml_err block_format_write(block_format_writer w, const char *class_name, const char *record, int record_len) {
+  HANDLE_EXCEPTIONS(
+    twml::BlockFormatWriter *writer = twml::getBlockFormatWriter(w);
+    writer->write(class_name, record, record_len););
+  return TWML_ERR_NONE;
+}
+
+twml_err block_format_flush(block_format_writer w) {
+  HANDLE_EXCEPTIONS(
+    twml::BlockFormatWriter *writer = twml::getBlockFormatWriter(w);
+    writer->flush(););
+  return TWML_ERR_NONE;
+}
+
+twml_err block_format_writer_delete(const block_format_writer w) {
+  HANDLE_EXCEPTIONS(
+    delete twml::getBlockFormatWriter(w););
+  return TWML_ERR_NONE;
+}
diff --git a/twml/libtwml/src/lib/CMakeLists.txt b/twml/libtwml/src/lib/CMakeLists.txt
new file mode 100644
index 000000000..6bf2a6e7c
--- /dev/null
+++ b/twml/libtwml/src/lib/CMakeLists.txt
@@ -0,0 +1,36 @@
+set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR})
+cmake_minimum_required(VERSION 2.8 FATAL_ERROR)
+cmake_policy(VERSION 2.8)
+
+
+set(TWML_VERSION "2.0.0")
+string(REPLACE "." ";" TWML_VERSION_LIST ${TWML_VERSION})
";" TWML_VERSION_LIST ${TWML_VERSION}) +list(GET TWML_VERSION_LIST 0 TWML_SOVERSION) + +execute_process( + COMMAND + $ENV{LIBTWML_HOME}/src/ops/scripts/get_inc.sh + RESULT_VARIABLE + TF_RES + OUTPUT_VARIABLE + TF_INC) + +file(GLOB_RECURSE sources *.cpp) + +set (CMAKE_CXX_FLAGS "-Wall -std=c++11 ${CMAKE_CXX_FLAGS} -fPIC") + +add_library(twml STATIC ${sources}) + +target_include_directories( + twml + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${TF_INC} # Absail dependency from tensorflow + ) + +set_target_properties(twml PROPERTIES + VERSION "${TWML_VERSION}" + SOVERSION "${TWML_SOVERSION}" + ) diff --git a/twml/libtwml/src/lib/CPPLINT.cfg b/twml/libtwml/src/lib/CPPLINT.cfg new file mode 100644 index 000000000..dfe873a9d --- /dev/null +++ b/twml/libtwml/src/lib/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=murmur_hash3.cpp \ No newline at end of file diff --git a/twml/libtwml/src/lib/DataRecord.cpp b/twml/libtwml/src/lib/DataRecord.cpp new file mode 100644 index 000000000..766422063 --- /dev/null +++ b/twml/libtwml/src/lib/DataRecord.cpp @@ -0,0 +1,72 @@ +#include "internal/thrift.h" +#include "internal/error.h" + +#include +#include +#include +#include + +#include +#include + +namespace twml { + +void DataRecord::decode(DataRecordReader &reader) { + uint8_t feature_type = reader.readByte(); + while (feature_type != TTYPE_STOP) { + int16_t field_id = reader.readInt16(); + switch (field_id) { + case DR_BINARY: + reader.readBinary(feature_type, this); + break; + case DR_CONTINUOUS: + reader.readContinuous(feature_type, this); + break; + case DR_DISCRETE: + reader.readDiscrete(feature_type, this); + break; + case DR_STRING: + reader.readString(feature_type, this); + break; + case DR_SPARSE_BINARY: + reader.readSparseBinary(feature_type, this); + break; + case DR_SPARSE_CONTINUOUS: + reader.readSparseContinuous(feature_type, this); + break; + case DR_BLOB: + reader.readBlob(feature_type, this); + break; + case DR_GENERAL_TENSOR: + reader.readTensor(feature_type, dynamic_cast(this)); + break; + case DR_SPARSE_TENSOR: + reader.readSparseTensor(feature_type, dynamic_cast(this)); + break; + default: + throw ThriftInvalidField(field_id, "DataRecord::decode"); + } + feature_type = reader.readByte(); + } +} + +void DataRecord::addLabel(int64_t id, double label) { + m_labels[id] = label; +} + +void DataRecord::addWeight(int64_t id, double val) { + m_weights[id] = val; +} + +void DataRecord::clear() { + std::fill(m_labels.begin(), m_labels.end(), std::nanf("")); + std::fill(m_weights.begin(), m_weights.end(), 0.0); + m_binary.clear(); + m_continuous.clear(); + m_discrete.clear(); + m_string.clear(); + m_sparsebinary.clear(); + m_sparsecontinuous.clear(); +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/DataRecordReader.cpp b/twml/libtwml/src/lib/DataRecordReader.cpp new file mode 100644 index 000000000..f151e07a7 --- /dev/null +++ b/twml/libtwml/src/lib/DataRecordReader.cpp @@ -0,0 +1,230 @@ +#include "internal/thrift.h" +#include "internal/error.h" +#include +#include + +#include + +namespace twml { + +inline std::string bufferToString(int32_t str_len, const uint8_t *str) { + return std::string(str, str + str_len); +} + + +bool DataRecordReader::keepKey(const int64_t &key, int64_t &code) { + auto it = m_keep_map->find(key); + if (it == m_keep_map->end()) return false; + code = it->second; + return true; +} + +bool DataRecordReader::isLabel(const int64_t &key, int64_t &code) { + if (m_labels_map == nullptr) return false; + auto it = 
m_labels_map->find(key); + if (it == m_labels_map->end()) return false; + code = it->second; + return true; +} + +bool DataRecordReader::isWeight(const int64_t &key, int64_t &code) { + if (m_weights_map == nullptr) return false; + auto it = m_weights_map->find(key); + if (it == m_weights_map->end()) return false; + code = it->second; + return true; +} + + +void DataRecordReader::readBinary( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_SET, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + int32_t length = readInt32(); + int64_t id, code; +#ifdef USE_DENSE_HASH + record->m_binary.resize(2 * length); +#else + record->m_binary.reserve(2 * length); +#endif + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + record->m_binary.insert(id); + if (isLabel(id, code)) { + record->addLabel(code); + } + } +} + +void DataRecordReader::readContinuous( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_DOUBLE, "value_type"); + + int32_t length = readInt32(); + int64_t id, code; +#ifdef USE_DENSE_HASH + record->m_continuous.resize(2 * length); +#else + record->m_continuous.reserve(2 * length); +#endif + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + double val = readDouble(); + if (!std::isnan(val)) { + record->m_continuous[id] = val; + } + if (isLabel(id, code)) { + record->addLabel(code, val); + } else if (isWeight(id, code)) { + record->addWeight(code, val); + } + } +} + +void DataRecordReader::readDiscrete( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "value_type"); + + int32_t length = readInt32(); + int64_t id; +#ifdef USE_DENSE_HASH + record->m_discrete.resize(2 * length); +#else + record->m_discrete.reserve(2 * length); +#endif + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + record->m_discrete[id] = readInt64(); + } +} + +void DataRecordReader::readString( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "value_type"); + int32_t length = readInt32(); + int64_t id; + +#ifdef USE_DENSE_HASH + record->m_string.resize(2 * length); +#else + record->m_string.reserve(2 * length); +#endif + + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + const uint8_t *begin = nullptr; + int32_t str_len = getRawBuffer(&begin); + record->m_string[id] = bufferToString(str_len, begin); + } +} + +void DataRecordReader::readSparseBinary( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_SET, "value_type"); + + int32_t length = readInt32(); + int64_t id, code; + +#ifdef USE_DENSE_HASH + record->m_sparsebinary.resize(2 * length); +#else + record->m_sparsebinary.reserve(2 * length); +#endif + + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "set:key_type"); + int32_t set_length = readInt32(); + if (keepKey(id, code)) { + record->m_sparsebinary[id].reserve(set_length); + for (int32_t j = 0; j < set_length; j++) { + const uint8_t *begin = nullptr; + int32_t 
str_len = getRawBuffer(&begin); + record->m_sparsebinary[id].push_back(bufferToString(str_len, begin)); + } + } else { + for (int32_t j = 0; j < set_length; j++) { + int32_t str_len = readInt32(); + skipLength(str_len); + } + } + } +} + +void DataRecordReader::readSparseContinuous( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_MAP, "value_type"); + + int32_t length = readInt32(); + int64_t id, code; + +#ifdef USE_DENSE_HASH + record->m_sparsecontinuous.resize(2 * length); +#else + record->m_sparsecontinuous.reserve(2 * length); +#endif + + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "map::key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_DOUBLE, "map::value_type"); + int32_t map_length = readInt32(); + if (keepKey(id, code)) { + record->m_sparsecontinuous[id].reserve(map_length); + for (int32_t j = 0; j < map_length; j++) { + const uint8_t *begin = nullptr; + int32_t str_len = getRawBuffer(&begin); + double val = readDouble(); + if (!std::isnan(val)) { + record->m_sparsecontinuous[id].push_back({bufferToString(str_len, begin), val}); + } + } + } else { + for (int32_t j = 0; j < map_length; j++) { + int32_t str_len = readInt32(); + skipLength(str_len); + skip(); + } + } + } +} + +void DataRecordReader::readBlob( + const int feature_type, + DataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "value_type"); + + int32_t length = readInt32(); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepKey(id, code)) { + const uint8_t *begin = nullptr; + int32_t blob_len = getRawBuffer(&begin); + record->m_blob[id] = std::vector(begin, begin + blob_len); + } else { + int32_t str_len = readInt32(); + skipLength(str_len); + } + } +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/DataRecordWriter.cpp b/twml/libtwml/src/lib/DataRecordWriter.cpp new file mode 100644 index 000000000..e12a50d48 --- /dev/null +++ b/twml/libtwml/src/lib/DataRecordWriter.cpp @@ -0,0 +1,162 @@ +#include "internal/error.h" +#include "internal/thrift.h" + +#include +#include +#include +#include +#include + +using namespace twml::io; + +namespace twml { + +void DataRecordWriter::writeBinary(twml::DataRecord &record) { + const DataRecord::BinaryFeatures bin_features = record.getBinary(); + + if (bin_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_SET, DR_BINARY); + m_thrift_writer.writeListHeader(TTYPE_I64, bin_features.size()); + + for (const auto &it : bin_features) { + m_thrift_writer.writeInt64(it); + } + } +} + +void DataRecordWriter::writeContinuous(twml::DataRecord &record) { + const DataRecord::ContinuousFeatures cont_features = record.getContinuous(); + + if (cont_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_CONTINUOUS); + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_DOUBLE, cont_features.size()); + + for (const auto &it : cont_features) { + m_thrift_writer.writeInt64(it.first); + m_thrift_writer.writeDouble(it.second); + } + } +} + +void DataRecordWriter::writeDiscrete(twml::DataRecord &record) { + const DataRecord::DiscreteFeatures disc_features = record.getDiscrete(); + + if (disc_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_DISCRETE); + 
m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_I64, disc_features.size()); + + for (const auto &it : disc_features) { + m_thrift_writer.writeInt64(it.first); + m_thrift_writer.writeInt64(it.second); + } + } +} + +void DataRecordWriter::writeString(twml::DataRecord &record) { + const DataRecord::StringFeatures str_features = record.getString(); + + if (str_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_STRING); + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_STRING, str_features.size()); + + + for (const auto &it : str_features) { + m_thrift_writer.writeInt64(it.first); + m_thrift_writer.writeString(it.second); + } + } +} + +// convert from internal representation list<(i64, string)> +// to Thrift representation map> +void DataRecordWriter::writeSparseBinaryFeatures(twml::DataRecord &record) { + const DataRecord::SparseBinaryFeatures sp_bin_features = record.getSparseBinary(); + + // write map> as Thrift + if (sp_bin_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_SPARSE_BINARY); + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_SET, sp_bin_features.size()); + + for (auto key_vals : sp_bin_features) { + m_thrift_writer.writeInt64(key_vals.first); + m_thrift_writer.writeListHeader(TTYPE_STRING, key_vals.second.size()); + + for (auto name : key_vals.second) + m_thrift_writer.writeString(name); + } + } +} + +// convert from internal representation list<(i64, string, double)> +// to Thrift representation map> +void DataRecordWriter::writeSparseContinuousFeatures(twml::DataRecord &record) { + const DataRecord::SparseContinuousFeatures sp_cont_features = record.getSparseContinuous(); + + // write map> as Thrift + if (sp_cont_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_SPARSE_CONTINUOUS); + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_MAP, sp_cont_features.size()); + + for (auto key_vals : sp_cont_features) { + m_thrift_writer.writeInt64(key_vals.first); + + if (key_vals.second.size() == 0) + throw IOError(IOError::MALFORMED_MEMORY_RECORD); + + m_thrift_writer.writeMapHeader(TTYPE_STRING, TTYPE_DOUBLE, key_vals.second.size()); + + for (auto map_str_double : key_vals.second) { + m_thrift_writer.writeString(map_str_double.first); + m_thrift_writer.writeDouble(map_str_double.second); + } + } + } +} + +void DataRecordWriter::writeBlobFeatures(twml::DataRecord &record) { + const DataRecord::BlobFeatures blob_features = record.getBlob(); + + if (blob_features.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_BLOB); + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_STRING, blob_features.size()); + + for (const auto &it : blob_features) { + m_thrift_writer.writeInt64(it.first); + std::vector value = it.second; + m_thrift_writer.writeBinary(value.data(), value.size()); + } + } +} + +void DataRecordWriter::writeDenseTensors(twml::DataRecord &record) { + TensorRecord::RawTensors raw_tensors = record.getRawTensors(); + if (raw_tensors.size() > 0) { + m_thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_GENERAL_TENSOR); + m_tensor_writer.write(record); + } +} + +TWMLAPI uint32_t DataRecordWriter::getRecordsWritten() { + return m_records_written; +} + +TWMLAPI uint64_t DataRecordWriter::write(twml::DataRecord &record) { + uint64_t bytes_written_before = m_thrift_writer.getBytesWritten(); + + writeBinary(record); + writeContinuous(record); + writeDiscrete(record); + writeString(record); + writeSparseBinaryFeatures(record); + writeSparseContinuousFeatures(record); + writeBlobFeatures(record); + 
writeDenseTensors(record); + // TODO add sparse tensor field + + m_thrift_writer.writeStructStop(); + m_records_written++; + + return m_thrift_writer.getBytesWritten() - bytes_written_before; +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/HashedDataRecord.cpp b/twml/libtwml/src/lib/HashedDataRecord.cpp new file mode 100644 index 000000000..6bbecee70 --- /dev/null +++ b/twml/libtwml/src/lib/HashedDataRecord.cpp @@ -0,0 +1,80 @@ +#include "internal/thrift.h" +#include "internal/error.h" + +#include +#include +#include + +#include +#include +#include + +namespace twml { + +void HashedDataRecord::decode(HashedDataRecordReader &reader) { + uint8_t feature_type = reader.readByte(); + while (feature_type != TTYPE_STOP) { + int16_t field_id = reader.readInt16(); + switch (field_id) { + case DR_BINARY: + reader.readBinary(feature_type, this); + break; + case DR_CONTINUOUS: + reader.readContinuous(feature_type, this); + break; + case DR_DISCRETE: + reader.readDiscrete(feature_type, this); + break; + case DR_STRING: + reader.readString(feature_type, this); + break; + case DR_SPARSE_BINARY: + reader.readSparseBinary(feature_type, this); + break; + case DR_SPARSE_CONTINUOUS: + reader.readSparseContinuous(feature_type, this); + break; + case DR_BLOB: + reader.readBlob(feature_type, this); + break; + case DR_GENERAL_TENSOR: + reader.readTensor(feature_type, dynamic_cast(this)); + break; + case DR_SPARSE_TENSOR: + reader.readSparseTensor(feature_type, dynamic_cast(this)); + break; + default: + throw ThriftInvalidField(field_id, "HashedDataRecord::readThrift"); + } + feature_type = reader.readByte(); + } +} + +void HashedDataRecord::addKey(int64_t key, int64_t transformed_key, + int64_t code, uint8_t type, double value) { + m_keys.push_back(key); + m_transformed_keys.push_back(transformed_key); + m_values.push_back(value); + m_codes.push_back(code); + m_types.push_back(type); +} + +void HashedDataRecord::addLabel(int64_t id, double label) { + m_labels[id] = label; +} + +void HashedDataRecord::addWeight(int64_t id, double val) { + m_weights[id] = val; +} + +void HashedDataRecord::clear() { + std::fill(m_labels.begin(), m_labels.end(), std::nanf("")); + std::fill(m_weights.begin(), m_weights.end(), 0.0); + m_keys.clear(); + m_transformed_keys.clear(); + m_values.clear(); + m_codes.clear(); + m_types.clear(); +} + +} // namespace twml \ No newline at end of file diff --git a/twml/libtwml/src/lib/HashedDataRecordReader.cpp b/twml/libtwml/src/lib/HashedDataRecordReader.cpp new file mode 100644 index 000000000..93c86001b --- /dev/null +++ b/twml/libtwml/src/lib/HashedDataRecordReader.cpp @@ -0,0 +1,218 @@ +#include "internal/thrift.h" +#include "internal/error.h" + +#include +#include +#include +#include + +namespace twml { + +bool HashedDataRecordReader::keepId(const int64_t &key, int64_t &code) { + auto it = m_keep_map->find(key); + if (it == m_keep_map->end()) return false; + code = it->second; + return true; +} + +bool HashedDataRecordReader::isLabel(const int64_t &key, int64_t &code) { + if (m_labels_map == nullptr) return false; + auto it = m_labels_map->find(key); + if (it == m_labels_map->end()) return false; + code = it->second; + return true; +} + +bool HashedDataRecordReader::isWeight(const int64_t &key, int64_t &code) { + if (m_weights_map == nullptr) return false; + auto it = m_weights_map->find(key); + if (it == m_weights_map->end()) return false; + code = it->second; + return true; +} + +void HashedDataRecordReader::readBinary( + const int feature_type, + HashedDataRecord *record) { + 
CHECK_THRIFT_TYPE(feature_type, TTYPE_SET, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + record->addKey(id, id, code, DR_BINARY); + } else if (isLabel(id, code)) { + record->addLabel(code); + } + } +} + +void HashedDataRecordReader::readContinuous( + const int feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_DOUBLE, "value_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + double value = readDouble(); + if (!std::isnan(value)) { + record->addKey(id, id, code, DR_CONTINUOUS, value); + } + } else if (isLabel(id, code)) { + record->addLabel(code, readDouble()); + } else if (isWeight(id, code)) { + record->addWeight(code, readDouble()); + } else { + skip(); + } + } +} + +void HashedDataRecordReader::readDiscrete( + const int feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "value_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + int64_t transformed_key = mixDiscreteIdAndValue(id, readInt64()); + record->addKey(id, transformed_key, code, DR_DISCRETE); + } else { + skip(); + } + } +} + +void HashedDataRecordReader::readString( + const int feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "value_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + const uint8_t *begin = nullptr; + int32_t str_len = getRawBuffer(&begin); + int64_t transformed_key = mixStringIdAndValue(id, str_len, begin); + record->addKey(id, transformed_key, code, DR_STRING); + } else { + int32_t str_len = readInt32(); + skipLength(str_len); + } + } +} + +void HashedDataRecordReader::readSparseBinary( + const int feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_SET, "value_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "set:key_type"); + int32_t set_length = readInt32(); + for (int32_t j = 0; j < set_length; j++) { + const uint8_t *begin = nullptr; + int32_t str_len = getRawBuffer(&begin); + int64_t transformed_key = mixStringIdAndValue(id, str_len, begin); + record->addKey(id, transformed_key, code, DR_SPARSE_BINARY); + } + } else { + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "set:key_type"); + int32_t set_length = readInt32(); + for (int32_t j = 0; j < set_length; j++) { + int32_t str_len = readInt32(); + skipLength(str_len); + } + } + } +} + +void HashedDataRecordReader::readSparseContinuous( + const int 
feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_MAP, "value_type"); + + int32_t length = readInt32(); + record->extendSize(length); + int64_t id, code; + for (int32_t i = 0; i < length; i++) { + id = readInt64(); + if (keepId(id, code)) { + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "map::key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_DOUBLE, "map::value_type"); + int32_t map_length = readInt32(); + for (int32_t j = 0; j < map_length; j++) { + const uint8_t *begin = nullptr; + int32_t str_len = getRawBuffer(&begin); + int64_t transformed_key = 0; + switch(m_decode_mode) { + case DecodeMode::hash_fname_and_valname: + transformed_key = mixStringIdAndValue(id, str_len, begin); + break; + default: // m_decode_mode == DecodeMode::hash_valname == 0 is default + twml_get_feature_id(&transformed_key, str_len, reinterpret_cast(begin)); + } + double value = readDouble(); + if (!std::isnan(value)) { + record->addKey(id, transformed_key, code, DR_SPARSE_CONTINUOUS, value); + } + } + } else { + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "map::key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_DOUBLE, "map::value_type"); + int32_t map_length = readInt32(); + for (int32_t j = 0; j < map_length; j++) { + int32_t str_len = readInt32(); + skipLength(str_len); + skip(); + } + } + } +} + +void HashedDataRecordReader::readBlob( + const int feature_type, + HashedDataRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "value_type"); + + int32_t length = readInt32(); + int64_t id; + for (int32_t i = 0; i < length; i++) { + // Skips the BlobFeatures if they are defined or not in the FeatureConfig + id = readInt64(); + int32_t str_len = readInt32(); + skipLength(str_len); + } +} +} // namespace twml \ No newline at end of file diff --git a/twml/libtwml/src/lib/Hashmap.cpp b/twml/libtwml/src/lib/Hashmap.cpp new file mode 100644 index 000000000..4086e8a16 --- /dev/null +++ b/twml/libtwml/src/lib/Hashmap.cpp @@ -0,0 +1,380 @@ +#include "internal/khash.h" +#include "internal/error.h" +#include +#include +#include + +namespace twml { + HashMap::HashMap() : + m_hashmap(nullptr) { + TWML_CHECK(twml_hashmap_create(&m_hashmap), "Failed to create HashMap"); + } + + HashMap::~HashMap() { + // Do not throw exceptions from the destructor + twml_hashmap_delete(m_hashmap); + } + + void HashMap::clear() { + TWML_CHECK(twml_hashmap_clear(m_hashmap), "Failed to clear HashMap"); + } + + uint64_t HashMap::size() const { + uint64_t size; + TWML_CHECK(twml_hashmap_get_size(&size, m_hashmap), "Failed to get HashMap size"); + return size; + } + + int8_t HashMap::insert(const HashKey_t key) { + int8_t result; + TWML_CHECK(twml_hashmap_insert_key(&result, m_hashmap, key), + "Failed to insert key"); + return result; + } + + int8_t HashMap::insert(const HashKey_t key, const HashKey_t val) { + int8_t result; + TWML_CHECK(twml_hashmap_insert_key_and_value(&result, m_hashmap, key, val), + "Failed to insert key"); + return result; + } + + int8_t HashMap::get(HashVal_t &val, const HashKey_t key) const { + int8_t result; + TWML_CHECK(twml_hashmap_get_value(&result, &val, m_hashmap, key), + "Failed to insert key,value pair"); + return result; + } + + void HashMap::insert(Tensor &mask, const Tensor keys) { + TWML_CHECK(twml_hashmap_insert_keys(mask.getHandle(), m_hashmap, 
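Before the tensor-based entry points, the RAII wrapper above can be exercised with plain scalars; this usage sketch relies only on the methods defined in Hashmap.cpp (the include path is an assumption):

#include <cstdio>
#include <twml/HashMap.h>   // assumed include path

int main() {
  twml::HashMap map;                      // twml_hashmap_create under the hood
  map.insert(42, 7);                      // key 42 -> value 7
  twml::HashVal_t val = 0;
  if (map.get(val, 42)) {                 // returns a mask: nonzero if the key was found
    std::printf("42 -> %lld (size=%llu)\n",
                (long long)val, (unsigned long long)map.size());
  }
  map.clear();                            // the destructor later calls twml_hashmap_delete
  return 0;
}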
keys.getHandle()), + "Failed to insert keys tensor"); + } + + void HashMap::insert(Tensor &mask, const Tensor keys, const Tensor vals) { + TWML_CHECK(twml_hashmap_insert_keys_and_values(mask.getHandle(), m_hashmap, + keys.getHandle(), vals.getHandle()), + "Failed to insert keys,values tensor pair"); + } + + void HashMap::remove(const Tensor keys) { + TWML_CHECK(twml_hashmap_remove_keys(m_hashmap, keys.getHandle()), + "Failed to remove keys tensor"); + } + + void HashMap::get(Tensor &mask, Tensor &vals, const Tensor keys) const { + TWML_CHECK(twml_hashmap_get_values(mask.getHandle(), vals.getHandle(), + m_hashmap, keys.getHandle()), + "Failed to get values tensor"); + } + + void HashMap::getInplace(Tensor &mask, Tensor &keys_vals) const { + TWML_CHECK(twml_hashmap_get_values_inplace(mask.getHandle(), + keys_vals.getHandle(), + m_hashmap), + "Failed to get values tensor"); + } + + void HashMap::toTensors(Tensor &keys, Tensor &vals) const { + TWML_CHECK(twml_hashmap_to_tensors(keys.getHandle(), + vals.getHandle(), + m_hashmap), + "Failed to get keys,values tensors from HashMap"); + } +} // namespace twml + +using twml::HashKey_t; +using twml::HashVal_t; + +KHASH_MAP_INIT_INT64(HashKey_t, HashVal_t); +typedef khash_t(HashKey_t)* hash_map_t; + + +twml_err twml_hashmap_create(twml_hashmap *hashmap) { + hash_map_t *h = reinterpret_cast(hashmap); + *h = kh_init(HashKey_t); + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_clear(const twml_hashmap hashmap) { + hash_map_t h = (hash_map_t)hashmap; + kh_clear(HashKey_t, h); + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_get_size(uint64_t *size, const twml_hashmap hashmap) { + hash_map_t h = (hash_map_t)hashmap; + *size = kh_size(h); + return TWML_ERR_NONE; +} + + +twml_err twml_hashmap_delete(const twml_hashmap hashmap) { + hash_map_t h = (hash_map_t)hashmap; + kh_destroy(HashKey_t, h); + return TWML_ERR_NONE; +} + +// insert, remove, get single key / value +twml_err twml_hashmap_insert_key(int8_t *mask, + const twml_hashmap hashmap, + const HashKey_t key) { + hash_map_t h = (hash_map_t)hashmap; + int ret = 0; + khiter_t k = kh_put(HashKey_t, h, key, &ret); + *mask = ret >= 0; + if (*mask) { + HashVal_t v = kh_size(h); + kh_value(h, k) = v; + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_insert_key_and_value(int8_t *mask, twml_hashmap hashmap, + const HashKey_t key, const HashVal_t val) { + hash_map_t h = (hash_map_t)hashmap; + int ret = 0; + khiter_t k = kh_put(HashKey_t, h, key, &ret); + *mask = ret >= 0; + if (*mask) { + kh_value(h, k) = val; + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_remove_key(const twml_hashmap hashmap, + const HashKey_t key) { + hash_map_t h = (hash_map_t)hashmap; + khiter_t k = kh_get(HashKey_t, h, key); + if (k != kh_end(h)) { + kh_del(HashKey_t, h, k); + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_get_value(int8_t *mask, HashVal_t *val, + const twml_hashmap hashmap, const HashKey_t key) { + hash_map_t h = (hash_map_t)hashmap; + khiter_t k = kh_get(HashKey_t, h, key); + if (k == kh_end(h)) { + *mask = false; + } else { + *val = kh_value(h, k); + *mask = true; + } + return TWML_ERR_NONE; +} + +// insert, get, remove tensors of keys / values +twml_err twml_hashmap_insert_keys(twml_tensor masks, + const twml_hashmap hashmap, + const twml_tensor keys) { + auto masks_tensor = twml::getTensor(masks); + auto keys_tensor = twml::getConstTensor(keys); + + if (masks_tensor->getType() != TWML_TYPE_INT8) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return 
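The C layer here is a thin veneer over khash, and its mask convention comes straight from kh_put's return value (negative on error, 0 when the key already existed, positive when a new slot was created). A standalone reminder of that pattern, assuming khash.h is on the include path:

#include <cstdio>
#include <cstdint>
#include "khash.h"

KHASH_MAP_INIT_INT64(i64map, int64_t)

int main() {
  khash_t(i64map) *h = kh_init(i64map);
  int ret = 0;
  khiter_t k = kh_put(i64map, h, 42, &ret);   // ret > 0: newly inserted
  kh_value(h, k) = 7;
  k = kh_get(i64map, h, 42);                  // lookup returns an iterator
  if (k != kh_end(h))
    std::printf("42 -> %lld\n", (long long)kh_value(h, k));
  kh_del(i64map, h, k);                       // note: kh_del takes the iterator, not the key
  kh_destroy(i64map, h);
  return 0;
}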
TWML_ERR_TYPE; + } + + if (keys_tensor->getNumElements() != masks_tensor->getNumElements()) { + return TWML_ERR_SIZE; + } + + int8_t *mptr = masks_tensor->getData(); + const HashKey_t *kptr = keys_tensor->getData(); + + uint64_t num_elements = keys_tensor->getNumElements(); + + hash_map_t h = (hash_map_t)hashmap; + for (uint64_t i = 0; i < num_elements; i++) { + int ret = 0; + khiter_t k = kh_put(HashKey_t, h, kptr[i], &ret); + mptr[i] = ret >= 0; + if (mptr[i]) { + HashVal_t v = kh_size(h); + kh_value(h, k) = v; + } + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_insert_keys_and_values(twml_tensor masks, + twml_hashmap hashmap, + const twml_tensor keys, + const twml_tensor vals) { + auto masks_tensor = twml::getTensor(masks); + auto keys_tensor = twml::getConstTensor(keys); + auto vals_tensor = twml::getConstTensor(vals); + + if (masks_tensor->getType() != TWML_TYPE_INT8) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (vals_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getNumElements() != vals_tensor->getNumElements() || + keys_tensor->getNumElements() != masks_tensor->getNumElements()) { + return TWML_ERR_SIZE; + } + + int8_t *mptr = masks_tensor->getData(); + const HashKey_t *kptr = keys_tensor->getData(); + const HashVal_t *vptr = twml::getConstTensor(vals)->getData(); + + uint64_t num_elements = keys_tensor->getNumElements(); + + hash_map_t h = (hash_map_t)hashmap; + for (uint64_t i = 0; i < num_elements; i++) { + int ret = 0; + khiter_t k = kh_put(HashKey_t, h, kptr[i], &ret); + mptr[i] = ret >= 0; + if (mptr[i]) { + kh_value(h, k) = vptr[i]; + } + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_remove_keys(const twml_hashmap hashmap, + const twml_tensor keys) { + auto keys_tensor = twml::getConstTensor(keys); + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + const HashKey_t *kptr = keys_tensor->getData(); + uint64_t num_elements = keys_tensor->getNumElements(); + + hash_map_t h = (hash_map_t)hashmap; + for (uint64_t i = 0; i < num_elements; i++) { + khiter_t k = kh_get(HashKey_t, h, kptr[i]); + if (k != kh_end(h)) { + kh_del(HashKey_t, h, kptr[i]); + } + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_get_values(twml_tensor masks, twml_tensor vals, + const twml_hashmap hashmap, const twml_tensor keys) { + auto masks_tensor = twml::getTensor(masks); + auto vals_tensor = twml::getTensor(vals); + auto keys_tensor = twml::getConstTensor(keys); + + if (masks_tensor->getType() != TWML_TYPE_INT8) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (vals_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getNumElements() != vals_tensor->getNumElements() || + keys_tensor->getNumElements() != masks_tensor->getNumElements()) { + return TWML_ERR_SIZE; + } + + int8_t *mptr = masks_tensor->getData(); + HashVal_t *vptr = vals_tensor->getData(); + const HashKey_t *kptr = keys_tensor->getData(); + + uint64_t num_elements = keys_tensor->getNumElements(); + + hash_map_t h = (hash_map_t)hashmap; + for (uint64_t i = 0; i < num_elements; i++) { + khiter_t k = kh_get(HashKey_t, h, kptr[i]); + if (k == kh_end(h)) { + mptr[i] = false; + } else { + mptr[i] = true; + vptr[i] = kh_value(h, k); + } + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_get_values_inplace(twml_tensor masks, twml_tensor keys_vals, + const twml_hashmap hashmap) 
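Taken together, the batched entry points above let a caller round-trip whole key tensors in one call; this sketch wraps stack arrays in tensor handles (twml_tensor_create is defined later in this diff, in Tensor.cpp) and checks the mask convention. The umbrella header path is an assumption and error handling is abbreviated:

#include <cstdint>
#include <cstdio>
#include <twml/twml.h>   // assumed umbrella header for the C API

int main() {
  int64_t keys[3] = {1, 2, 3};
  int64_t vals[3] = {10, 20, 30};
  int64_t out[3]  = {0, 0, 0};
  int8_t  mask[3] = {0, 0, 0};
  uint64_t dims[1] = {3}, strides[1] = {1};

  twml_tensor tkeys, tvals, tmask, tout;
  twml_tensor_create(&tkeys, keys, 1, dims, strides, TWML_TYPE_INT64);
  twml_tensor_create(&tvals, vals, 1, dims, strides, TWML_TYPE_INT64);
  twml_tensor_create(&tmask, mask, 1, dims, strides, TWML_TYPE_INT8);
  twml_tensor_create(&tout,  out,  1, dims, strides, TWML_TYPE_INT64);

  twml_hashmap h;
  twml_hashmap_create(&h);
  twml_hashmap_insert_keys_and_values(tmask, h, tkeys, tvals);
  twml_hashmap_get_values(tmask, tout, h, tkeys);   // mask[i] = 1 where keys[i] was found

  for (int i = 0; i < 3; i++)
    std::printf("%lld -> %lld (hit=%d)\n",
                (long long)keys[i], (long long)out[i], (int)mask[i]);

  twml_hashmap_delete(h);
  twml_tensor_delete(tkeys);  // handles are views; the stack arrays stay owned by main
  twml_tensor_delete(tvals);
  twml_tensor_delete(tmask);
  twml_tensor_delete(tout);
  return 0;
}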
{ + auto masks_tensor = twml::getTensor(masks); + auto keys_tensor = twml::getTensor(keys_vals); + + if (masks_tensor->getType() != TWML_TYPE_INT8) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (keys_tensor->getNumElements() != masks_tensor->getNumElements()) { + return TWML_ERR_SIZE; + } + + int8_t *mptr = masks_tensor->getData(); + HashKey_t *kptr = keys_tensor->getData(); + + uint64_t num_elements = keys_tensor->getNumElements(); + + hash_map_t h = (hash_map_t)hashmap; + for (uint64_t i = 0; i < num_elements; i++) { + khiter_t k = kh_get(HashKey_t, h, kptr[i]); + if (k == kh_end(h)) { + mptr[i] = false; + } else { + mptr[i] = true; + kptr[i] = kh_value(h, k); + } + } + return TWML_ERR_NONE; +} + +twml_err twml_hashmap_to_tensors(twml_tensor keys, twml_tensor vals, + const twml_hashmap hashmap) { + hash_map_t h = (hash_map_t)hashmap; + const uint64_t size = kh_size(h); + + auto keys_tensor = twml::getTensor(keys); + auto vals_tensor = twml::getTensor(vals); + + if (keys_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (vals_tensor->getType() != TWML_TYPE_INT64) { + return TWML_ERR_TYPE; + } + + if (size != keys_tensor->getNumElements() || + size != vals_tensor->getNumElements()) { + return TWML_ERR_SIZE; + } + + HashKey_t *kptr = keys_tensor->getData(); + HashVal_t *vptr = vals_tensor->getData(); + + HashKey_t key, i = 0; + HashKey_t val; + + kh_foreach(h, key, val, { + kptr[i] = key; + vptr[i] = val; + i++; + }); + + return TWML_ERR_NONE; +} diff --git a/twml/libtwml/src/lib/Tensor.cpp b/twml/libtwml/src/lib/Tensor.cpp new file mode 100644 index 000000000..d610d9316 --- /dev/null +++ b/twml/libtwml/src/lib/Tensor.cpp @@ -0,0 +1,191 @@ +#include "internal/error.h" +#include +#include +#include +#include +#include + +namespace twml { + +using std::vector; + +Tensor::Tensor(void *data, int ndims, const uint64_t *dims, const uint64_t *strides, twml_type type) : + m_type(type), m_data(data), + m_dims(dims, dims + ndims), + m_strides(strides, strides + ndims) { +} + +Tensor::Tensor(void *data, + const vector &dims, + const vector &strides, + twml_type type) : + m_type(type), m_data(data), + m_dims(dims.begin(), dims.end()), + m_strides(strides.begin(), strides.end()) { + if (dims.size() != strides.size()) { + throw twml::Error(TWML_ERR_SIZE, "The number size of dims and strides don't match"); + } +} + +int Tensor::getNumDims() const { + return static_cast(m_dims.size()); +} + +uint64_t Tensor::getDim(int id) const { + if (id >= this->getNumDims()) { + throw twml::Error(TWML_ERR_SIZE, "Requested dimension exceeds tensor dimension"); + } + return m_dims[id]; +} + +uint64_t Tensor::getStride(int id) const { + if (id >= this->getNumDims()) { + throw twml::Error(TWML_ERR_SIZE, "Requested dimension exceeds tensor dimension"); + } + return m_strides[id]; +} + +uint64_t Tensor::getNumElements() const { + return std::accumulate(m_dims.begin(), m_dims.end(), 1, std::multiplies()); +} + +twml_type Tensor::getType() const { + return m_type; +} + +twml_tensor Tensor::getHandle() { + return reinterpret_cast(this); +} + +const twml_tensor Tensor::getHandle() const { + return reinterpret_cast(const_cast(this)); +} + +const Tensor *getConstTensor(const twml_tensor t) { + return reinterpret_cast(t); +} + +Tensor *getTensor(twml_tensor t) { + return reinterpret_cast(t); +} + +#define INSTANTIATE(T) \ + template<> TWMLAPI T *Tensor::getData() { \ + if ((twml_type)Type::type != m_type) { \ + throw 
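twml::Tensor, defined above, is a non-owning view: it records a raw data pointer plus dims, strides (in elements), and a dtype, and the typed getData<T>() accessors check the dtype on every access. A small sketch of wrapping an existing row-major buffer (include path assumed):

#include <cstdint>
#include <twml/Tensor.h>   // assumed include path

int main() {
  float buf[6] = {0, 1, 2, 3, 4, 5};     // a 2x3 matrix, row-major
  uint64_t dims[2] = {2, 3};
  uint64_t strides[2] = {3, 1};          // element strides, not bytes
  twml::Tensor t(buf, 2, dims, strides, TWML_TYPE_FLOAT);

  float *data = t.getData<float>();      // OK: dtype matches
  // t.getData<double>() would throw twml::Error(TWML_ERR_TYPE, ...)
  (void)data;
  return t.getNumElements() == 6 ? 0 : 1;  // 2 * 3 elements, no copy made
}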
twml::Error(TWML_ERR_TYPE, \ + "Requested invalid type"); \ + } \ + return reinterpret_cast(m_data); \ + } \ + template<> TWMLAPI const T *Tensor::getData() const { \ + if ((twml_type)Type::type != m_type) { \ + throw twml::Error(TWML_ERR_TYPE, \ + "Requested invalid type"); \ + } \ + return (const T *)m_data; \ + } \ + +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(int8_t) +INSTANTIATE(uint8_t) +INSTANTIATE(float) +INSTANTIATE(double) +INSTANTIATE(bool) +INSTANTIATE(std::string) + +// This is used for the C api. No checks needed for void. +template<> TWMLAPI void *Tensor::getData() { + return m_data; +} +template<> TWMLAPI const void *Tensor::getData() const { + return (const void *)m_data; +} + +std::string getTypeName(twml_type type) { + switch (type) { + case TWML_TYPE_FLOAT32 : return "float32"; + case TWML_TYPE_FLOAT64 : return "float64"; + case TWML_TYPE_INT32 : return "int32"; + case TWML_TYPE_INT64 : return "int64"; + case TWML_TYPE_INT8 : return "int8"; + case TWML_TYPE_UINT8 : return "uint8"; + case TWML_TYPE_BOOL : return "bool"; + case TWML_TYPE_STRING : return "string"; + case TWML_TYPE_UNKNOWN : return "Unknown type"; + } + throw twml::Error(TWML_ERR_TYPE, "Uknown type"); +} + +uint64_t getSizeOf(twml_type dtype) { + switch (dtype) { + case TWML_TYPE_FLOAT : return 4; + case TWML_TYPE_DOUBLE : return 8; + case TWML_TYPE_INT64 : return 8; + case TWML_TYPE_INT32 : return 4; + case TWML_TYPE_UINT8 : return 1; + case TWML_TYPE_BOOL : return 1; + case TWML_TYPE_INT8 : return 1; + case TWML_TYPE_STRING : + throw twml::Error(TWML_ERR_THRIFT, "getSizeOf not supported for strings"); + case TWML_TYPE_UNKNOWN: + throw twml::Error(TWML_ERR_THRIFT, "Can't get size of unknown types"); + } + throw twml::Error(TWML_ERR_THRIFT, "Invalid twml_type"); +} + +} // namespace twml + +twml_err twml_tensor_create(twml_tensor *t, void *data, int ndims, uint64_t *dims, + uint64_t *strides, twml_type type) { + HANDLE_EXCEPTIONS( + twml::Tensor *res = new twml::Tensor(data, ndims, dims, strides, type); + *t = reinterpret_cast(res);); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_delete(const twml_tensor t) { + HANDLE_EXCEPTIONS( + delete twml::getConstTensor(t);); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_type(twml_type *type, const twml_tensor t) { + HANDLE_EXCEPTIONS( + *type = twml::getConstTensor(t)->getType();); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_data(void **data, const twml_tensor t) { + HANDLE_EXCEPTIONS( + *data = twml::getTensor(t)->getData();); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_dim(uint64_t *dim, const twml_tensor t, int id) { + HANDLE_EXCEPTIONS( + const twml::Tensor *tensor = twml::getConstTensor(t); + *dim = tensor->getDim(id);); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_stride(uint64_t *stride, const twml_tensor t, int id) { + HANDLE_EXCEPTIONS( + const twml::Tensor *tensor = twml::getConstTensor(t); + *stride = tensor->getStride(id);); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_num_dims(int *ndim, const twml_tensor t) { + HANDLE_EXCEPTIONS( + const twml::Tensor *tensor = twml::getConstTensor(t); + *ndim = tensor->getNumDims();); + return TWML_ERR_NONE; +} + +twml_err twml_tensor_get_num_elements(uint64_t *nelements, const twml_tensor t) { + HANDLE_EXCEPTIONS( + const twml::Tensor *tensor = twml::getConstTensor(t); + *nelements = tensor->getNumElements();); + return TWML_ERR_NONE; +} diff --git a/twml/libtwml/src/lib/TensorRecordReader.cpp b/twml/libtwml/src/lib/TensorRecordReader.cpp new file 
mode 100644 index 000000000..3ffb1b98a --- /dev/null +++ b/twml/libtwml/src/lib/TensorRecordReader.cpp @@ -0,0 +1,323 @@ +#include "internal/thrift.h" +#include "internal/error.h" +#include + +#include +#include + +namespace twml { + +template struct TensorTraits; + +#define INSTANTIATE(TYPE, THRIFT_TYPE, TWML_TYPE) \ + template<> struct TensorTraits { \ + static const TTYPES ThriftType = THRIFT_TYPE; \ + static const twml_type TwmlType = TWML_TYPE; \ + }; \ + +INSTANTIATE(int64_t, TTYPE_I64, TWML_TYPE_INT64) +INSTANTIATE(int32_t, TTYPE_I32, TWML_TYPE_INT32) +INSTANTIATE(double, TTYPE_DOUBLE, TWML_TYPE_DOUBLE) +INSTANTIATE(bool, TTYPE_BOOL, TWML_TYPE_BOOL) + +static +std::vector calcStrides(const std::vector &shape) { + int ndims = static_cast(shape.size()); + std::vector strides(ndims); + uint64_t stride = 1; + for (int i = ndims-1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } + return strides; +} + +static twml_type getTwmlType(int dtype) { + // Convert tensor.thrift enum to twml enum + switch (dtype) { + case DATA_TYPE_FLOAT: + return TWML_TYPE_FLOAT; + case DATA_TYPE_DOUBLE: + return TWML_TYPE_DOUBLE; + case DATA_TYPE_INT64: + return TWML_TYPE_INT64; + case DATA_TYPE_INT32: + return TWML_TYPE_INT32; + case DATA_TYPE_UINT8: + return TWML_TYPE_UINT8; + case DATA_TYPE_STRING: + return TWML_TYPE_STRING; + case DATA_TYPE_BOOL: + return TWML_TYPE_BOOL; + } + return TWML_TYPE_UNKNOWN; +} + +std::vector TensorRecordReader::readShape() { + int32_t length = readInt32(); + + std::vector shape; + shape.reserve(length); + for (int32_t i = 0; i < length; i++) { + shape.push_back(static_cast(readInt64())); + } + + return shape; +} + +template +RawTensor TensorRecordReader::readTypedTensor() { + std::vector shape; + int32_t length = 0; + const uint8_t *data = nullptr; + uint64_t raw_length = 0; + uint8_t field_type = TTYPE_STOP; + + while ((field_type = readByte()) != TTYPE_STOP) { + int16_t field_id = readInt16(); + switch (field_id) { + case 1: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "data"); + CHECK_THRIFT_TYPE(readByte(), TensorTraits::ThriftType, "data_type"); + length = getRawBuffer(&data); + raw_length = length * sizeof(T); + break; + case 2: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "shape"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "shape_type"); + shape = readShape(); + break; + default: + throw ThriftInvalidField(field_id, "TensorRecordReader::readTypedTensor"); + } + } + + // data is required + if (data == nullptr) { + throw twml::Error(TWML_ERR_THRIFT, "data field not found for TypedTensor"); + } + + // shape is optional + if (shape.size() == 0) { + shape.push_back((uint64_t)length); + } + + // TODO: Try avoiding stride calculation + std::vector strides = calcStrides(shape); + // FIXME: Try to use const void * in Tensors. 
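calcStrides above produces standard row-major (C-order) strides: the innermost dimension is contiguous and each earlier stride is the product of all trailing dims. A standalone restatement with a worked example:

#include <cstdint>
#include <vector>

// Same rule as calcStrides: innermost dim contiguous, outer strides multiply up.
std::vector<uint64_t> rowMajorStrides(const std::vector<uint64_t> &shape) {
  std::vector<uint64_t> strides(shape.size());
  uint64_t stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--) {
    strides[i] = stride;
    stride *= shape[i];
  }
  return strides;
}

// rowMajorStrides({2, 3, 4}) == {12, 4, 1}:
// element (i, j, k) lives at flat offset i*12 + j*4 + k.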
+ return RawTensor(const_cast(static_cast(data)), + shape, strides, (twml_type)TensorTraits::TwmlType, true, raw_length); +} + +RawTensor TensorRecordReader::readRawTypedTensor() { + std::vector shape; + const uint8_t *data = nullptr; + twml_type type = TWML_TYPE_UNKNOWN; + uint64_t raw_length = 0; + uint8_t field_type = TTYPE_STOP; + + while ((field_type = readByte()) != TTYPE_STOP) { + int16_t field_id = readInt16(); + switch (field_id) { + case 1: + CHECK_THRIFT_TYPE(field_type, TTYPE_I32, "DataType"); + type = getTwmlType(readInt32()); + break; + case 2: + CHECK_THRIFT_TYPE(field_type, TTYPE_STRING, "content"); + raw_length = getRawBuffer(&data); + break; + case 3: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "shape"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "shape_type"); + shape = readShape(); + break; + default: + throw ThriftInvalidField(field_id, "TensorRecordReader::readRawTypedTensor"); + } + } + + // data type is required + if (type == TWML_TYPE_UNKNOWN) { + throw twml::Error(TWML_ERR_THRIFT, "DataType is a required field for RawTypedTensor"); + } + + // data is required + if (data == nullptr) { + throw twml::Error(TWML_ERR_THRIFT, "content is a required field for RawTypedTensor"); + } + + // shape is optional in the thrift file, but it is really required for string types. + if (shape.size() == 0) { + if (type == TWML_TYPE_STRING) { + throw twml::Error(TWML_ERR_THRIFT, "shape required for string types in RawTypedTensor"); + } + shape.push_back((uint64_t)(raw_length / getSizeOf(type))); + } + + // TODO: Try avoiding stride calculation + std::vector strides = calcStrides(shape); + // FIXME: Try to use const void * data inside Tensors. + return RawTensor(const_cast(static_cast(data)), + shape, strides, type, false, raw_length); +} + +RawTensor TensorRecordReader::readStringTensor() { + std::vector shape; + int32_t length = 0; + const uint8_t *data = nullptr; + uint64_t raw_length = 0; + uint8_t field_type = TTYPE_STOP; + const uint8_t *dummy = nullptr; + + while ((field_type = readByte()) != TTYPE_STOP) { + int16_t field_id = readInt16(); + switch (field_id) { + case 1: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "data"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRING, "data_type"); + length = readInt32(); + // Store the current location of the byte stream. + // Use this at to "deocde strings" at a later point. + data = getBuffer(); + for (int32_t i = 0; i < length; i++) { + // Skip reading the strings + getRawBuffer(&dummy); + } + raw_length = length; + break; + case 2: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "shape"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "shape_type"); + shape = readShape(); + break; + default: + throw ThriftInvalidField(field_id, "TensorRecordReader::readTypedTensor"); + } + } + + // data is required + if (data == nullptr) { + throw twml::Error(TWML_ERR_THRIFT, "data field not found for TypedTensor"); + } + + // shape is optional + if (shape.size() == 0) { + shape.push_back((uint64_t)length); + } + + // TODO: Try avoiding stride calculation + std::vector strides = calcStrides(shape); + // FIXME: Try to use const void * in Tensors. + return RawTensor(const_cast(static_cast(data)), + shape, strides, TWML_TYPE_UINT8, false, raw_length); +} + +RawTensor TensorRecordReader::readGeneralTensor() { + // No loop is required because GeneralTensor is union. It is going to contain one field only. 
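readRawTypedTensor above treats the shape field as optional: when it is absent, the element count is inferred from the byte length of the content field, which is why string tensors (whose elements are variable-length) must carry an explicit shape. A sketch of the inference rule:

#include <cstdint>
#include <vector>

// Missing shape for a fixed-width dtype: infer a flat vector.
// E.g. a 32-byte content field with DataType INT64 (8 bytes/element) -> shape {4}.
std::vector<uint64_t> inferFlatShape(uint64_t raw_length_bytes, uint64_t elem_size_bytes) {
  return {raw_length_bytes / elem_size_bytes};
}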
+ // All the fields are structs + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRUCT, "type"); + int16_t field_id = readInt16(); + RawTensor output; + + switch (field_id) { + case GT_RAW: + output = readRawTypedTensor(); + break; + case GT_STRING: + output = readStringTensor(); + break; + case GT_INT32: + output = readTypedTensor(); + break; + case GT_INT64: + output = readTypedTensor(); + break; + case GT_FLOAT: + case GT_DOUBLE: + // Store both FloatTensor and DoubleTensor as double tensor as both are list of doubles. + output = readTypedTensor(); + break; + case GT_BOOL: + output = readTypedTensor(); + break; + default: + throw ThriftInvalidField(field_id, "TensorRecordReader::readGeneralTensor()"); + } + + CHECK_THRIFT_TYPE(readByte(), TTYPE_STOP, "stop"); + return output; +} + +RawSparseTensor TensorRecordReader::readCOOSparseTensor() { + std::vector shape; + uint8_t field_type = TTYPE_STOP; + RawTensor indices, values; + + while ((field_type = readByte()) != TTYPE_STOP) { + int16_t field_id = readInt16(); + switch (field_id) { + case 1: + CHECK_THRIFT_TYPE(field_type, TTYPE_LIST, "shape"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "shape_type"); + shape = readShape(); + break; + case 2: + indices = readTypedTensor(); + break; + case 3: + values = readGeneralTensor(); + break; + default: + throw twml::Error(TWML_ERR_THRIFT, "Invalid field when deocidng COOSparseTensor"); + } + } + + return RawSparseTensor(indices, values, shape); +} + +void TensorRecordReader::readTensor(const int feature_type, TensorRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRUCT, "value_type"); + + int32_t length = readInt32(); + for (int32_t i = 0; i < length; i++) { + int64_t id = readInt64(); + record->m_tensors.emplace(id, readGeneralTensor()); + } +} + +void TensorRecordReader::readSparseTensor(const int feature_type, TensorRecord *record) { + CHECK_THRIFT_TYPE(feature_type, TTYPE_MAP, "type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_I64, "key_type"); + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRUCT, "value_type"); + + int32_t length = readInt32(); + for (int32_t i = 0; i < length; i++) { + int64_t id = readInt64(); + + // No loop is required because SparseTensor is union. It is going to contain one field only. + // All the fields are structs + CHECK_THRIFT_TYPE(readByte(), TTYPE_STRUCT, "field"); + int16_t field_id = readInt16(); + RawSparseTensor output; + + // Only COOSparsetensor is supported. + switch (field_id) { + case SP_COO: + output = readCOOSparseTensor(); + break; + default: + throw ThriftInvalidField(field_id, "TensorRecordReader::readSparseTensor()"); + } + + // Read the last byte of the struct. + CHECK_THRIFT_TYPE(readByte(), TTYPE_STOP, "stop"); + + // Add to the map. 
+ record->m_sparse_tensors.emplace(id, output); + } +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/TensorRecordWriter.cpp b/twml/libtwml/src/lib/TensorRecordWriter.cpp new file mode 100644 index 000000000..b1fe98e64 --- /dev/null +++ b/twml/libtwml/src/lib/TensorRecordWriter.cpp @@ -0,0 +1,162 @@ +#include "internal/error.h" +#include "internal/thrift.h" + +#include +#include +#include +#include + +using namespace twml::io; + +namespace twml { + +static int32_t getRawThriftType(twml_type dtype) { + // convert twml enum to tensor.thrift enum + switch (dtype) { + case TWML_TYPE_FLOAT: + return DATA_TYPE_FLOAT; + case TWML_TYPE_DOUBLE: + return DATA_TYPE_DOUBLE; + case TWML_TYPE_INT64: + return DATA_TYPE_INT64; + case TWML_TYPE_INT32: + return DATA_TYPE_INT32; + case TWML_TYPE_UINT8: + return DATA_TYPE_UINT8; + case TWML_TYPE_STRING: + return DATA_TYPE_STRING; + case TWML_TYPE_BOOL: + return DATA_TYPE_BOOL; + default: + throw IOError(IOError::UNSUPPORTED_OUTPUT_TYPE); + } +} + +void TensorRecordWriter::writeTensor(const RawTensor &tensor) { + if (tensor.getType() == TWML_TYPE_INT32) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_INT32); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_I32, tensor.getNumElements()); + + const int32_t *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeInt32(data[i]); + + } else if (tensor.getType() == TWML_TYPE_INT64) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_INT64); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_I64, tensor.getNumElements()); + + const int64_t *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeInt64(data[i]); + + } else if (tensor.getType() == TWML_TYPE_FLOAT) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_FLOAT); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_DOUBLE, tensor.getNumElements()); + + const float *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeDouble(static_cast(data[i])); + + } else if (tensor.getType() == TWML_TYPE_DOUBLE) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_DOUBLE); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_DOUBLE, tensor.getNumElements()); + + const double *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeDouble(data[i]); + + } else if (tensor.getType() == TWML_TYPE_STRING) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_STRING); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_STRING, tensor.getNumElements()); + + const std::string *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeString(data[i]); + + } else if (tensor.getType() == TWML_TYPE_BOOL) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_BOOL); + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 1); + m_thrift_writer.writeListHeader(TTYPE_BOOL, tensor.getNumElements()); + + const bool *data = tensor.getData(); + + for (uint64_t i = 0; i < tensor.getNumElements(); i++) + m_thrift_writer.writeBool(data[i]); + + } else { + throw IOError(IOError::UNSUPPORTED_OUTPUT_TYPE); + } + + // write tensor shape field + 
m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 2); + m_thrift_writer.writeListHeader(TTYPE_I64, tensor.getNumDims()); + + for (uint64_t i = 0; i < tensor.getNumDims(); i++) + m_thrift_writer.writeInt64(tensor.getDim(i)); + + m_thrift_writer.writeStructStop(); + m_thrift_writer.writeStructStop(); +} + +void TensorRecordWriter::writeRawTensor(const RawTensor &tensor) { + m_thrift_writer.writeStructFieldHeader(TTYPE_STRUCT, GT_RAW); + + // dataType field + m_thrift_writer.writeStructFieldHeader(TTYPE_I32, 1); + m_thrift_writer.writeInt32(getRawThriftType(tensor.getType())); + + // content field + uint64_t type_size = getSizeOf(tensor.getType()); + m_thrift_writer.writeStructFieldHeader(TTYPE_STRING, 2); + const uint8_t *data = reinterpret_cast(tensor.getData()); + m_thrift_writer.writeBinary(data, tensor.getNumElements() * type_size); + + // shape field + m_thrift_writer.writeStructFieldHeader(TTYPE_LIST, 3); + m_thrift_writer.writeListHeader(TTYPE_I64, tensor.getNumDims()); + + for (uint64_t i = 0; i < tensor.getNumDims(); i++) + m_thrift_writer.writeInt64(tensor.getDim(i)); + + m_thrift_writer.writeStructStop(); + m_thrift_writer.writeStructStop(); +} + +TWMLAPI uint32_t TensorRecordWriter::getRecordsWritten() { + return m_records_written; +} + +// Caller (usually DataRecordWriter) must precede with struct header field +// like thrift_writer.writeStructFieldHeader(TTYPE_MAP, DR_GENERAL_TENSOR) +TWMLAPI uint64_t TensorRecordWriter::write(twml::TensorRecord &record) { + uint64_t bytes_written_before = m_thrift_writer.getBytesWritten(); + + m_thrift_writer.writeMapHeader(TTYPE_I64, TTYPE_STRUCT, record.getRawTensors().size()); + + for (auto id_tensor_pairs : record.getRawTensors()) { + m_thrift_writer.writeInt64(id_tensor_pairs.first); + + // all tensors written as RawTensor Thrift except for StringTensors + // this avoids the overhead of converting little endian to big endian + if (id_tensor_pairs.second.getType() == TWML_TYPE_STRING) + writeTensor(id_tensor_pairs.second); + else + writeRawTensor(id_tensor_pairs.second); + } + + m_records_written++; + + return m_thrift_writer.getBytesWritten() - bytes_written_before; +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/ThriftReader.cpp b/twml/libtwml/src/lib/ThriftReader.cpp new file mode 100644 index 000000000..bceb74c13 --- /dev/null +++ b/twml/libtwml/src/lib/ThriftReader.cpp @@ -0,0 +1,33 @@ +#include "internal/endianutils.h" + +#include +#include + +#include + +namespace twml { + +uint8_t ThriftReader::readByte() { + return readDirect(); +} + +int16_t ThriftReader::readInt16() { + return betoh16(readDirect()); +} + +int32_t ThriftReader::readInt32() { + return betoh32(readDirect()); +} + +int64_t ThriftReader::readInt64() { + return betoh64(readDirect()); +} + +double ThriftReader::readDouble() { + double val; + int64_t *val_proxy = reinterpret_cast(&val); + *val_proxy = readInt64(); + return val; +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/ThriftWriter.cpp b/twml/libtwml/src/lib/ThriftWriter.cpp new file mode 100644 index 000000000..4f298a154 --- /dev/null +++ b/twml/libtwml/src/lib/ThriftWriter.cpp @@ -0,0 +1,91 @@ +#include "internal/endianutils.h" +#include "internal/error.h" +#include "internal/thrift.h" + +#include +#include +#include + +#include + +using namespace twml::io; + +namespace twml { + +template inline +uint64_t ThriftWriter::write(T val) { + if (!m_dry_run) { + if (m_bytes_written + sizeof(T) > m_buffer_size) + throw IOError(IOError::DESTINATION_LARGER_THAN_CAPACITY); + memcpy(m_buffer, &val, 
sizeof(T)); + m_buffer += sizeof(T); + } + m_bytes_written += sizeof(T); + return sizeof(T); +} + +TWMLAPI uint64_t ThriftWriter::getBytesWritten() { + return m_bytes_written; +} + +TWMLAPI uint64_t ThriftWriter::writeStructFieldHeader(int8_t field_type, int16_t field_id) { + return writeInt8(field_type) + writeInt16(field_id); +} + +TWMLAPI uint64_t ThriftWriter::writeStructStop() { + return writeInt8(static_cast(TTYPE_STOP)); +} + +TWMLAPI uint64_t ThriftWriter::writeListHeader(int8_t element_type, int32_t num_elems) { + return writeInt8(element_type) + writeInt32(num_elems); +} + +TWMLAPI uint64_t ThriftWriter::writeMapHeader(int8_t key_type, int8_t val_type, int32_t num_elems) { + return writeInt8(key_type) + writeInt8(val_type) + writeInt32(num_elems); +} + +TWMLAPI uint64_t ThriftWriter::writeDouble(double val) { + int64_t bin_value; + memcpy(&bin_value, &val, sizeof(int64_t)); + return writeInt64(bin_value); +} + +TWMLAPI uint64_t ThriftWriter::writeInt8(int8_t val) { + return write(val); +} + +TWMLAPI uint64_t ThriftWriter::writeInt16(int16_t val) { + return write(betoh16(val)); +} + +TWMLAPI uint64_t ThriftWriter::writeInt32(int32_t val) { + return write(betoh32(val)); +} + +TWMLAPI uint64_t ThriftWriter::writeInt64(int64_t val) { + return write(betoh64(val)); +} + +TWMLAPI uint64_t ThriftWriter::writeBinary(const uint8_t *bytes, int32_t num_bytes) { + writeInt32(num_bytes); + + if (!m_dry_run) { + if (m_bytes_written + num_bytes > m_buffer_size) + throw IOError(IOError::DESTINATION_LARGER_THAN_CAPACITY); + memcpy(m_buffer, bytes, num_bytes); + m_buffer += num_bytes; + } + m_bytes_written += num_bytes; + + return 4 + num_bytes; +} + +TWMLAPI uint64_t ThriftWriter::writeString(std::string str) { + return writeBinary(reinterpret_cast(str.data()), str.length()); +} + +TWMLAPI uint64_t ThriftWriter::writeBool(bool val) { + return write(val); +} + +} // namespace twml diff --git a/twml/libtwml/src/lib/discretizer_impl.cpp b/twml/libtwml/src/lib/discretizer_impl.cpp new file mode 100644 index 000000000..3f161341e --- /dev/null +++ b/twml/libtwml/src/lib/discretizer_impl.cpp @@ -0,0 +1,167 @@ +#include "internal/interpolate.h" +#include "internal/error.h" +#include +#include + +namespace twml { + // it is assumed that start_compute and end_compute are valid + template + void discretizerInfer(Tensor &output_keys, + Tensor &output_vals, + const Tensor &input_ids, + const Tensor &input_vals, + const Tensor &bin_ids, + const Tensor &bin_vals, + const Tensor &feature_offsets, + int output_bits, + const Map &ID_to_index, + int64_t start_compute, + int64_t end_compute, + int64_t output_start) { + auto out_keysData = output_keys.getData(); + auto out_valsData = output_vals.getData(); + uint64_t out_keysStride = output_keys.getStride(0); + uint64_t out_valsStride = output_vals.getStride(0); + + auto in_idsData = input_ids.getData(); + auto in_valsData = input_vals.getData(); + uint64_t in_idsStride = input_ids.getStride(0); + uint64_t in_valsStride = input_vals.getStride(0); + + auto xsData = bin_vals.getData(); + auto ysData = bin_ids.getData(); + uint64_t xsStride = bin_vals.getStride(0); + uint64_t ysStride = bin_ids.getStride(0); + + auto offsetData = feature_offsets.getData(); + + uint64_t total_bins = bin_ids.getNumElements(); + uint64_t fsize = feature_offsets.getNumElements(); + + uint64_t output_size = (1 << output_bits); + + for (uint64_t i = start_compute; i < end_compute; i++) { + int64_t feature_ID = in_idsData[i * in_idsStride]; + T val = in_valsData[i * in_valsStride]; + + auto 
iter = ID_to_index.find(feature_ID); + if (iter == ID_to_index.end()) { + // feature not calibrated + // modulo add operation for new key from feature ID + int64_t ikey = feature_ID % (output_size - total_bins) + total_bins; + out_keysData[(i + output_start - start_compute) * out_keysStride] = ikey; + out_valsData[(i + output_start - start_compute) * out_valsStride] = val; + continue; + } + + int64_t ikey = iter->second; + + // Perform interpolation + uint64_t offset = offsetData[ikey]; + uint64_t next_offset = (ikey == (int64_t)(fsize - 1)) ? total_bins : offsetData[ikey + 1]; + uint64_t mainSize = next_offset - offset; + + const T *lxsData = xsData + offset; + const int64_t *lysData = ysData + offset; + int64_t okey; + okey = interpolation(lxsData, xsStride, + lysData, ysStride, + val, mainSize, + NEAREST, 0); + out_keysData[(i + output_start - start_compute) * out_keysStride] = okey; + out_valsData[(i + output_start - start_compute) * out_valsStride] = 1; + } + } + + void discretizerInfer(Tensor &output_keys, + Tensor &output_vals, + const Tensor &input_ids, + const Tensor &input_vals, + const Tensor &bin_ids, + const Tensor &bin_vals, + const Tensor &feature_offsets, + int output_bits, + const Map &ID_to_index, + int start_compute, + int end_compute, + int output_start) { + if (input_ids.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "input_ids must be a Long Tensor"); + } + + if (output_keys.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "output_keys must be a Long Tensor"); + } + + if (bin_ids.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "bin_ids must be a Long Tensor"); + } + + if (feature_offsets.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "bin_ids must be a Long Tensor"); + } + + if (input_vals.getType() != bin_vals.getType()) { + throw twml::Error(TWML_ERR_TYPE, + "Data type of input_vals does not match type of bin_vals"); + } + + if (bin_vals.getNumDims() != 1) { + throw twml::Error(TWML_ERR_SIZE, + "bin_vals must be 1 Dimensional"); + } + + if (bin_ids.getNumDims() != 1) { + throw twml::Error(TWML_ERR_SIZE, + "bin_ids must be 1 Dimensional"); + } + + if (bin_vals.getNumElements() != bin_ids.getNumElements()) { + throw twml::Error(TWML_ERR_SIZE, + "Dimensions of bin_vals and bin_ids do not match"); + } + + if (feature_offsets.getStride(0) != 1) { + throw twml::Error(TWML_ERR_SIZE, + "feature_offsets must be contiguous"); + } + + uint64_t size = input_ids.getDim(0); + if (end_compute == -1) { + end_compute = size; + } + + if (start_compute < 0 || start_compute >= size) { + throw twml::Error(TWML_ERR_SIZE, + "start_compute out of range"); + } + + if (end_compute < -1 || end_compute > size) { + throw twml::Error(TWML_ERR_SIZE, + "end_compute out of range"); + } + + if (start_compute > end_compute && end_compute != -1) { + throw twml::Error(TWML_ERR_SIZE, + "must have start_compute <= end_compute, or end_compute==-1"); + } + + switch (input_vals.getType()) { + case TWML_TYPE_FLOAT: + twml::discretizerInfer(output_keys, output_vals, + input_ids, input_vals, + bin_ids, bin_vals, feature_offsets, output_bits, ID_to_index, + start_compute, end_compute, output_start); + break; + case TWML_TYPE_DOUBLE: + twml::discretizerInfer(output_keys, output_vals, + input_ids, input_vals, + bin_ids, bin_vals, feature_offsets, output_bits, ID_to_index, + start_compute, end_compute, output_start); + break; + default: + throw twml::Error(TWML_ERR_TYPE, + "Unsupported datatype for discretizerInfer"); + } + } +} // 
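In the uncalibrated branch of discretizerInfer above, unseen feature IDs are folded into the part of the output key space left over after the calibrated bins, so they cannot collide with real bin ids. A worked restatement:

#include <cstdint>

// Fallback key for a feature id that was never calibrated.
int64_t fallbackKey(int64_t feature_ID, int output_bits, int64_t total_bins) {
  int64_t output_size = int64_t(1) << output_bits;   // e.g. 16 keys for output_bits = 4
  return feature_ID % (output_size - total_bins) + total_bins;
}

// fallbackKey(137, 4, 10) == 137 % 6 + 10 == 15: uncalibrated ids land in
// [total_bins, output_size), leaving [0, total_bins) to the calibrated bins.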
namespace twml diff --git a/twml/libtwml/src/lib/functions.cpp b/twml/libtwml/src/lib/functions.cpp new file mode 100644 index 000000000..b7af3c0ac --- /dev/null +++ b/twml/libtwml/src/lib/functions.cpp @@ -0,0 +1,158 @@ +#include "internal/error.h" +#include "internal/murmur_hash3.h" +#include "internal/utf_converter.h" +#include +#include +#include + +namespace twml { + + template + void add1(Tensor &output, const Tensor input) { + T *odata = output.getData(); + const T *idata = input.getData(); + const uint64_t num_elements = input.getNumElements(); + + for (uint64_t i = 0; i < num_elements; i++) { + odata[i] = idata[i] + 1; + } + } + + template + void copy(Tensor &output, const Tensor input) { + T *odata = output.getData(); + const T *idata = input.getData(); + const uint64_t num_elements = input.getNumElements(); + + for (uint64_t i = 0; i < num_elements; i++) { + odata[i] = idata[i]; + } + } + + void add1(Tensor &output, const Tensor input) { + auto type = input.getType(); + if (output.getType() != type) { + throw twml::Error(TWML_ERR_TYPE, "Output type does not match input type"); + } + + if (output.getNumElements() != input.getNumElements()) { + throw twml::Error(TWML_ERR_SIZE, "Output size does not match input size"); + } + + // TODO: Implement an easier dispatch function + switch (type) { + case TWML_TYPE_FLOAT: + twml::add1(output, input); + break; + case TWML_TYPE_DOUBLE: + twml::add1(output, input); + break; + default: + throw twml::Error(TWML_ERR_TYPE, "add1 only supports float and double tensors"); + } + } + + void copy(Tensor &output, const Tensor input) { + auto type = input.getType(); + if (output.getType() != type) { + throw twml::Error(TWML_ERR_TYPE, "Output type does not match input type"); + } + + if (output.getNumElements() != input.getNumElements()) { + throw twml::Error(TWML_ERR_SIZE, "Output size does not match input size"); + } + + // TODO: Implement an easier dispatch function + switch (type) { + case TWML_TYPE_FLOAT: + twml::copy(output, input); + break; + case TWML_TYPE_DOUBLE: + twml::copy(output, input); + break; + default: + throw twml::Error(TWML_ERR_TYPE, "copy only supports float and double tensors"); + } + } + + int64_t featureId(const std::string &feature) { + const char *str = feature.c_str(); + uint64_t len = feature.size(); + int64_t id = 0; + TWML_CHECK(twml_get_feature_id(&id, len, str), "Error getting featureId"); + return id; + } +} // namespace twml + +twml_err twml_add1(twml_tensor output, const twml_tensor input) { + HANDLE_EXCEPTIONS( + auto out = twml::getTensor(output); + auto in = twml::getConstTensor(input); + twml::add1(*out, *in);); + return TWML_ERR_NONE; +} + +twml_err twml_copy(twml_tensor output, const twml_tensor input) { + HANDLE_EXCEPTIONS( + auto out = twml::getTensor(output); + auto in = twml::getConstTensor(input); + twml::copy(*out, *in);); + return TWML_ERR_NONE; +} + +inline twml_err twml_get_feature_id_internal(int64_t *result, + uint64_t out_size, uint16_t *out, + uint64_t out2_size, uint16_t *out2, + const uint64_t len, const char *str) { + uint64_t k = 0; + for (uint64_t i = 0; i < len; i++) { + if (str[i] == '#') { + k = i; + break; + } + } + + uint8_t hash[16]; + if (k != 0) { + ssize_t n = utf8_to_utf16((const uint8_t *) str, k, out, out_size); + if (n < 0) throw std::invalid_argument("error while converting from utf8 to utf16"); + + MurmurHash3_x64_128(out, n * sizeof(uint16_t), 0, out2); + n = utf8_to_utf16((const uint8_t *) (str + k + 1), len - k - 1, &out2[4], out2_size - 8); + if (n < 0) throw 
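A usage sketch for the element-wise helpers above, driven through their C entry points; the include path and header name are assumptions, but the signatures match the wrappers defined in functions.cpp:

#include <cstdint>
#include <cstdio>
#include <twml/functions.h>   // assumed header for twml_add1 / twml_tensor_create

int main() {
  float in[3]  = {1.f, 2.f, 3.f};
  float out[3] = {0.f, 0.f, 0.f};
  uint64_t dims[1] = {3}, strides[1] = {1};

  twml_tensor tin, tout;
  twml_tensor_create(&tin,  in,  1, dims, strides, TWML_TYPE_FLOAT);
  twml_tensor_create(&tout, out, 1, dims, strides, TWML_TYPE_FLOAT);

  if (twml_add1(tout, tin) == TWML_ERR_NONE)           // out[i] = in[i] + 1
    std::printf("%g %g %g\n", out[0], out[1], out[2]); // 2 3 4

  twml_tensor_delete(tin);
  twml_tensor_delete(tout);
  return 0;
}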
std::invalid_argument("error while converting from utf8 to utf16"); + + MurmurHash3_x64_128(out2, (n * sizeof(uint16_t)) + 8, 0, hash); + } else { + ssize_t n = utf8_to_utf16((const uint8_t *)str, len, out, out_size); + if (n < 0) throw std::invalid_argument("error while converting from utf8 to utf16"); + MurmurHash3_x64_128(out, n * sizeof(uint16_t), 0, hash); + } + int64_t id; + memcpy(&id, hash, sizeof(int64_t)); + *result = id; + + return TWML_ERR_NONE; +} + +static const int UTF16_STR_MAX_SIZE = 1024; + +twml_err twml_get_feature_id(int64_t *result, const uint64_t len, const char *str) { + try { + uint16_t out[UTF16_STR_MAX_SIZE]; + uint16_t out2[UTF16_STR_MAX_SIZE]; + return twml_get_feature_id_internal(result, + UTF16_STR_MAX_SIZE, out, + UTF16_STR_MAX_SIZE, out2, + len, str); + } catch(const std::invalid_argument &ex) { + // If the space on the stack is not enough, try using the heap. + // len + 1 is needed because a null terminating character is added at the end. + std::vector out(len + 1); + std::vector out2(len + 1); + return twml_get_feature_id_internal(result, + len + 1, out.data(), + len + 1, out2.data(), + len, str); + + } +} diff --git a/twml/libtwml/src/lib/hashing_discretizer_impl.cpp b/twml/libtwml/src/lib/hashing_discretizer_impl.cpp new file mode 100644 index 000000000..166242ffb --- /dev/null +++ b/twml/libtwml/src/lib/hashing_discretizer_impl.cpp @@ -0,0 +1,241 @@ +#include "internal/linear_search.h" +#include "internal/error.h" +#include +#include +#include + +namespace twml { + template + static int64_t lower_bound_search(const Tx *data, const Tx val, const int64_t buf_size) { + auto index_temp = std::lower_bound(data, data + buf_size, val); + return static_cast(index_temp - data); + } + + template + static int64_t upper_bound_search(const Tx *data, const Tx val, const int64_t buf_size) { + auto index_temp = std::upper_bound(data, data + buf_size, val); + return static_cast(index_temp - data); + } + + template + using search_method = int64_t (*)(const Tx *, const Tx, const int64_t); + + typedef uint64_t (*hash_signature)(uint64_t, int64_t, uint64_t); + + // uint64_t integer_multiplicative_hashing() + // + // A function to hash discretized feature_ids into one of 2**output_bits buckets. + // This function hashes the feature_ids to achieve a uniform distribution of + // IDs, so the hashed IDs are with high probability far apart + // Then, bucket_indices can simply be added, resulting in unique new IDs with high probability + // We integer hash again to again spread out the new IDs + // Finally we take the upper + // Required args: + // feature_id: + // The feature id of the feature to be hashed. + // bucket_index: + // The bucket index of the discretized feature value + // output_bits: + // The number of bits of output space for the features to be hashed into. + // + // Note - feature_ids may have arbitrary distribution within int32s + // Note - 64 bit feature_ids can be processed with this, but the upper + // 32 bits have no effect on the output + // e.g. all feature ids 0 through 255 exist in movie-lens. + // this hashing constant is good for 32 LSBs. will use N=32. 
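twml_get_feature_id above tries fixed-size stack buffers first and only falls back to exact-size heap buffers when the UTF-16 conversion reports overflow, keeping the common case allocation-free. The same pattern in isolation (convert is a hypothetical worker standing in for utf8_to_utf16):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical converter: throws when the destination is too small.
static size_t convert(const char *src, size_t len, uint16_t *dst, size_t cap) {
  if (len > cap) throw std::invalid_argument("buffer too small");
  for (size_t i = 0; i < len; i++) dst[i] = static_cast<unsigned char>(src[i]);
  return len;
}

size_t stackFirst(const char *src, size_t len) {
  try {
    uint16_t buf[1024];                    // fast path: no allocation
    return convert(src, len, buf, 1024);
  } catch (const std::invalid_argument &) {
    std::vector<uint16_t> buf(len + 1);    // slow path: exact-size heap buffer
    return convert(src, len, buf.data(), buf.size());
  }
}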
(can use N<32 also) + // this hashing constant is co-prime with 2**32, therefore we have that + // a != b, a and b in [0,2**32) + // implies + // f(a) != f(b) where f(x) = (hashing_constant * x) % (2**32) + // note that we are mostly ignoring the upper 32 bits, using modulo 2**32 arithmetic + uint64_t integer_multiplicative_hashing(uint64_t feature_id, + int64_t bucket_index, + uint64_t output_bits) { + // possibly use 14695981039346656037 for 64 bit unsigned?? + // = 20921 * 465383 * 1509404459 + // alternatively, 14695981039346656039 is prime + // We would also need to use N = 64 + const uint64_t hashing_constant = 2654435761; + const uint64_t N = 32; + // hash once to prevent problems from anomalous input id distributions + feature_id *= hashing_constant; + feature_id += bucket_index; + // this hash enables the following right shift operation + // without losing the bucket information (lower bits) + feature_id *= hashing_constant; + // output size is a power of 2 + feature_id >>= N - output_bits; + uint64_t mask = (1 << output_bits) - 1; + return mask & feature_id; + } + + uint64_t integer64_multiplicative_hashing(uint64_t feature_id, + int64_t bucket_index, + uint64_t output_bits) { + const uint64_t hashing_constant = 14695981039346656039UL; + const uint64_t N = 64; + // hash once to prevent problems from anomalous input id distributions + feature_id *= hashing_constant; + feature_id += bucket_index; + // this hash enables the following right shift operation + // without losing the bucket information (lower bits) + feature_id *= hashing_constant; + // output size is a power of 2 + feature_id >>= N - output_bits; + uint64_t mask = (1 << output_bits) - 1; + return mask & feature_id; + } + + int64_t option_bits(int64_t options, int64_t high, int64_t low) { + options >>= low; + options &= (1 << (high - low + 1)) - 1; + return options; + } + + // it is assumed that start_compute and end_compute are valid + template + void hashDiscretizerInfer(Tensor &output_keys, + Tensor &output_vals, + const Tensor &input_ids, + const Tensor &input_vals, + const Tensor &bin_vals, + int output_bits, + const Map &ID_to_index, + int64_t start_compute, + int64_t end_compute, + int64_t n_bin, + int64_t options) { + auto output_keys_data = output_keys.getData(); + auto output_vals_data = output_vals.getData(); + + auto input_ids_data = input_ids.getData(); + auto input_vals_data = input_vals.getData(); + + auto bin_vals_data = bin_vals.getData(); + + // The function pointer implementation removes the option_bits + // function call (might be inlined) and corresponding branch from + // the hot loop, but it prevents inlining these functions, so + // there will be function call overhead. Uncertain which would + // be faster, testing needed. Also, code optimizers do weird things... 
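The hashing scheme the comment above describes is easier to read in isolation. The constant 2654435761 is odd, hence invertible modulo 2^32, so the first multiply permutes the id space; the bucket index is added into the low bits; the second multiply spreads the bucket bits upward so the final shift can keep them. A standalone copy (using 1ULL in the mask to stay safe for large output_bits):

#include <cstdint>
#include <cstdio>

uint64_t mulHash(uint64_t feature_id, int64_t bucket_index, uint64_t output_bits) {
  const uint64_t c = 2654435761;     // odd => multiplication is a bijection mod 2^32
  feature_id *= c;                   // spread out arbitrary input ids
  feature_id += bucket_index;        // bucket lands in the low bits
  feature_id *= c;                   // mix again so the shift keeps bucket information
  feature_id >>= 32 - output_bits;   // take the top output_bits of the low 32
  return feature_id & ((1ULL << output_bits) - 1);
}

int main() {
  // Adjacent buckets of one feature map to well-separated output keys.
  for (int64_t b = 0; b < 4; b++)
    std::printf("bucket %lld -> key %llu\n",
                (long long)b, (unsigned long long)mulHash(255, b, 22));
  return 0;
}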
+ hash_signature hash_fn = integer_multiplicative_hashing; + switch (option_bits(options, 4, 2)) { + case 0: + hash_fn = integer_multiplicative_hashing; + break; + case 1: + hash_fn = integer64_multiplicative_hashing; + break; + default: + hash_fn = integer_multiplicative_hashing; + } + + search_method search_fn = lower_bound_search; + switch (option_bits(options, 1, 0)) { + case 0: + search_fn = lower_bound_search; + break; + case 1: + search_fn = linear_search; + break; + case 2: + search_fn = upper_bound_search; + break; + default: + search_fn = lower_bound_search; + } + + for (uint64_t i = start_compute; i < end_compute; i++) { + int64_t id = input_ids_data[i]; + T val = input_vals_data[i]; + + auto iter = ID_to_index.find(id); + if (iter != ID_to_index.end()) { + int64_t feature_idx = iter->second; + const T *bin_vals_start = bin_vals_data + feature_idx * n_bin; + int64_t out_bin_idx = search_fn(bin_vals_start, val, n_bin); + output_keys_data[i] = hash_fn(id, out_bin_idx, output_bits); + output_vals_data[i] = 1; + } else { + // feature not calibrated + output_keys_data[i] = id & ((1 << output_bits) - 1); + output_vals_data[i] = val; + } + } + } + + void hashDiscretizerInfer(Tensor &output_keys, + Tensor &output_vals, + const Tensor &input_ids, + const Tensor &input_vals, + int n_bin, + const Tensor &bin_vals, + int output_bits, + const Map &ID_to_index, + int start_compute, + int end_compute, + int64_t options) { + if (input_ids.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "input_ids must be a Long Tensor"); + } + + if (output_keys.getType() != TWML_TYPE_INT64) { + throw twml::Error(TWML_ERR_TYPE, "output_keys must be a Long Tensor"); + } + + if (input_vals.getType() != bin_vals.getType()) { + throw twml::Error(TWML_ERR_TYPE, + "Data type of input_vals does not match type of bin_vals"); + } + + if (bin_vals.getNumDims() != 1) { + throw twml::Error(TWML_ERR_SIZE, + "bin_vals must be 1 Dimensional"); + } + + uint64_t size = input_ids.getDim(0); + if (end_compute == -1) { + end_compute = size; + } + + if (start_compute < 0 || start_compute >= size) { + throw twml::Error(TWML_ERR_SIZE, + "start_compute out of range"); + } + + if (end_compute < -1 || end_compute > size) { + throw twml::Error(TWML_ERR_SIZE, + "end_compute out of range"); + } + + if (start_compute > end_compute && end_compute != -1) { + throw twml::Error(TWML_ERR_SIZE, + "must have start_compute <= end_compute, or end_compute==-1"); + } + + if (output_keys.getStride(0) != 1 || output_vals.getStride(0) != 1 || + input_ids.getStride(0) != 1 || input_vals.getStride(0) != 1 || + bin_vals.getStride(0) != 1) { + throw twml::Error(TWML_ERR_SIZE, + "All Strides must be 1."); + } + + switch (input_vals.getType()) { + case TWML_TYPE_FLOAT: + twml::hashDiscretizerInfer(output_keys, output_vals, + input_ids, input_vals, + bin_vals, output_bits, ID_to_index, + start_compute, end_compute, n_bin, options); + break; + case TWML_TYPE_DOUBLE: + twml::hashDiscretizerInfer(output_keys, output_vals, + input_ids, input_vals, + bin_vals, output_bits, ID_to_index, + start_compute, end_compute, n_bin, options); + break; + default: + throw twml::Error(TWML_ERR_TYPE, + "Unsupported datatype for hashDiscretizerInfer"); + } + } +} // namespace twml diff --git a/twml/libtwml/src/lib/internal/endianutils.h b/twml/libtwml/src/lib/internal/endianutils.h new file mode 100644 index 000000000..3b27797d7 --- /dev/null +++ b/twml/libtwml/src/lib/internal/endianutils.h @@ -0,0 +1,137 @@ +// +// endian_fix.h +// ImageCore +// +// For OSes 
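Back in hashDiscretizerInfer, option_bits slices one packed options word: bits 1..0 select the search method (0 lower_bound, 1 linear, 2 upper_bound) and bits 4..2 select the hash function (0 for the 32-bit hash, 1 for the 64-bit one). A quick trace of that encoding:

#include <cassert>
#include <cstdint>

// Standalone copy of option_bits for illustration.
int64_t optionBits(int64_t options, int64_t high, int64_t low) {
  options >>= low;
  options &= (1 << (high - low + 1)) - 1;
  return options;
}

int main() {
  int64_t options = 0b00101;                 // = 5
  assert(optionBits(options, 1, 0) == 1);    // search bits -> linear_search
  assert(optionBits(options, 4, 2) == 1);    // hash bits -> integer64_multiplicative_hashing
  return 0;
}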
diff --git a/twml/libtwml/src/lib/internal/endianutils.h b/twml/libtwml/src/lib/internal/endianutils.h
new file mode 100644
index 000000000..3b27797d7
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/endianutils.h
@@ -0,0 +1,137 @@
+//
+//  endian_fix.h
+//  ImageCore
+//
+//  For OSes that use glibc < 2.9 (like RHEL5)
+//
+#pragma once
+
+#ifdef __APPLE__
+#include <libkern/OSByteOrder.h>
+#define htobe16(x) OSSwapHostToBigInt16(x)
+#define htole16(x) OSSwapHostToLittleInt16(x)
+#define betoh16(x) OSSwapBigToHostInt16(x)
+#define letoh16(x) OSSwapLittleToHostInt16(x)
+#define htobe32(x) OSSwapHostToBigInt32(x)
+#define htole32(x) OSSwapHostToLittleInt32(x)
+#define betoh32(x) OSSwapBigToHostInt32(x)
+#define letoh32(x) OSSwapLittleToHostInt32(x)
+#define htobe64(x) OSSwapHostToBigInt64(x)
+#define htole64(x) OSSwapHostToLittleInt64(x)
+#define betoh64(x) OSSwapBigToHostInt64(x)
+#define letoh64(x) OSSwapLittleToHostInt64(x)
+#else
+#include <endian.h>
+#ifdef __USE_BSD
+/* Conversion interfaces. */
+#include <byteswap.h>
+
+#if __BYTE_ORDER == __LITTLE_ENDIAN
+#ifndef htobe16
+#define htobe16(x) __bswap_16(x)
+#endif
+#ifndef htole16
+#define htole16(x) (x)
+#endif
+#ifndef betoh16
+#define betoh16(x) __bswap_16(x)
+#endif
+#ifndef letoh16
+#define letoh16(x) (x)
+#endif
+
+#ifndef htobe32
+#define htobe32(x) __bswap_32(x)
+#endif
+#ifndef htole32
+#define htole32(x) (x)
+#endif
+#ifndef betoh32
+#define betoh32(x) __bswap_32(x)
+#endif
+#ifndef letoh32
+#define letoh32(x) (x)
+#endif
+
+#ifndef htobe64
+#define htobe64(x) __bswap_64(x)
+#endif
+#ifndef htole64
+#define htole64(x) (x)
+#endif
+#ifndef betoh64
+#define betoh64(x) __bswap_64(x)
+#endif
+#ifndef letoh64
+#define letoh64(x) (x)
+#endif
+
+#else /* __BYTE_ORDER == __LITTLE_ENDIAN */
+#ifndef htobe16
+#define htobe16(x) (x)
+#endif
+#ifndef htole16
+#define htole16(x) __bswap_16(x)
+#endif
+#ifndef be16toh
+#define be16toh(x) (x)
+#endif
+#ifndef le16toh
+#define le16toh(x) __bswap_16(x)
+#endif
+
+#ifndef htobe32
+#define htobe32(x) (x)
+#endif
+#ifndef htole32
+#define htole32(x) __bswap_32(x)
+#endif
+#ifndef betoh32
+#define betoh32(x) (x)
+#endif
+#ifndef letoh32
+#define letoh32(x) __bswap_32(x)
+#endif
+
+#ifndef htobe64
+#define htobe64(x) (x)
+#endif
+#ifndef htole64
+#define htole64(x) __bswap_64(x)
+#endif
+#ifndef betoh64
+#define betoh64(x) (x)
+#endif
+#ifndef letoh64
+#define letoh64(x) __bswap_64(x)
+#endif
+
+#endif /* __BYTE_ORDER == __LITTLE_ENDIAN */
+
+#else /* __USE_BSD */
+#ifndef betoh16
+#define betoh16 be16toh
+#endif
+
+#ifndef betoh32
+#define betoh32 be32toh
+#endif
+
+#ifndef betoh64
+#define betoh64 be64toh
+#endif
+
+#ifndef letoh16
+#define letoh16 le16toh
+#endif
+
+#ifndef letoh32
+#define letoh32 le32toh
+#endif
+
+#ifndef letoh64
+#define letoh64 le64toh
+#endif
+
+#endif /* __USE_BSD */
+#endif /* __APPLE__ */
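A quick illustration of what these shims provide (an editorial sketch, not part of this diff; the include path is assumed from the tree layout above): on a little-endian host htobe32 byte-swaps, while on a big-endian host it is the identity.

#include <cstdint>
#include <cstdio>
#include "internal/endianutils.h"  // path assumed, for illustration only

int main() {
  uint32_t host = 0x11223344;
  uint32_t big = htobe32(host);
  // Prints 44332211 on a little-endian machine, 11223344 on big-endian.
  std::printf("%08x\n", big);
  return 0;
}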
diff --git a/twml/libtwml/src/lib/internal/error.h b/twml/libtwml/src/lib/internal/error.h
new file mode 100644
index 000000000..3d1bc5441
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/error.h
@@ -0,0 +1,29 @@
+#pragma once
+#include <twml/error.h>
+#include <iostream>
+
+#define HANDLE_EXCEPTIONS(fn) do {                  \
+    try {                                           \
+      fn                                            \
+    } catch(const twml::Error &e) {                 \
+      std::cerr << e.what() << std::endl;           \
+      return e.err();                               \
+    } catch(...) {                                  \
+      std::cerr << "Unknown error\n";               \
+      return TWML_ERR_UNKNOWN;                      \
+    }                                               \
+  } while(0)
+
+#define TWML_CHECK(fn, msg) do {                    \
+    twml_err err = fn;                              \
+    if (err == TWML_ERR_NONE) break;                \
+    throw twml::Error(err, msg);                    \
+  } while(0)
+
+
+#define CHECK_THRIFT_TYPE(real_type, expected_type, type) do {  \
+    int real_type_val = real_type;                              \
+    if (real_type_val != expected_type) {                       \
+      throw twml::ThriftInvalidType(real_type_val, __func__, type); \
+    }                                                           \
+  } while(0)
diff --git a/twml/libtwml/src/lib/internal/interpolate.h b/twml/libtwml/src/lib/internal/interpolate.h
new file mode 100644
index 000000000..3e1daf53e
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/interpolate.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#ifdef __cplusplus
+#include <cstdint>
+namespace twml {
+
+  enum InterpolationMode {LINEAR, NEAREST};
+
+  template <typename Tx, typename Ty>
+  static Tx interpolation(const Tx *xsData, const int64_t xsStride,
+                          const Ty *ysData, const int64_t ysStride,
+                          const Tx val, const int64_t mainSize,
+                          const InterpolationMode mode,
+                          const int64_t lowest,
+                          const bool return_local_index = false) {
+    int64_t left = 0;
+    int64_t right = mainSize - 1;
+
+    if (val <= xsData[0]) {
+      right = 0;
+    } else if (val >= xsData[right * xsStride]) {
+      left = right;
+    } else {
+      while (left < right) {
+        int64_t middle = (left + right) / 2;
+
+        if (middle < mainSize - 1 &&
+            val >= xsData[middle * xsStride] &&
+            val <= xsData[(middle + 1) * xsStride]) {
+          left = middle;
+          right = middle + 1;
+          break;
+        } else if (val > xsData[middle * xsStride]) {
+          left = middle;
+        } else {
+          right = middle;
+        }
+      }
+      if (lowest) {
+        while (left > 0 &&
+               val >= xsData[(left - 1) * xsStride] &&
+               val == xsData[left * xsStride]) {
+          left--;
+          right--;
+        }
+      }
+    }
+
+    Ty out = 0;
+    if (return_local_index) {
+      out = left;
+    } else if (mode == NEAREST) {
+      out = ysData[left * ysStride];
+    } else {
+      int64_t leftys = left * ysStride;
+      int64_t rightys = right * ysStride;
+      int64_t leftxs = left * xsStride;
+      int64_t rightxs = right * xsStride;
+      if (right != left + 1 ||
+          xsData[leftxs] == xsData[rightxs]) {
+        out = ysData[leftys];
+      } else {
+        Tx xLeft = xsData[leftxs];
+        Tx xRight = xsData[rightxs];
+        Tx yLeft = ysData[leftys];
+        Tx ratio = (val - xLeft) / (xRight - xLeft);
+        out = ratio * (ysData[rightys] - yLeft) + yLeft;
+      }
+    }
+    return out;
+  }
+
+}  // namespace twml
+#endif
diff --git a/twml/libtwml/src/lib/internal/khash.h b/twml/libtwml/src/lib/internal/khash.h
new file mode 100644
index 000000000..c9075cbbc
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/khash.h
@@ -0,0 +1,627 @@
+/* The MIT License
+
+   Copyright (c) 2008, 2009, 2011 by Attractive Chaos <attractor@live.co.uk>
+
+   Permission is hereby granted, free of charge, to any person obtaining
+   a copy of this software and associated documentation files (the
+   "Software"), to deal in the Software without restriction, including
+   without limitation the rights to use, copy, modify, merge, publish,
+   distribute, sublicense, and/or sell copies of the Software, and to
+   permit persons to whom the Software is furnished to do so, subject to
+   the following conditions:
+
+   The above copyright notice and this permission notice shall be
+   included in all copies or substantial portions of the Software.
+
+   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+   NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +/* + An example: + +#include "khash.h" +KHASH_MAP_INIT_INT(32, char) +int main() { + int ret, is_missing; + khiter_t k; + khash_t(32) *h = kh_init(32); + k = kh_put(32, h, 5, &ret); + kh_value(h, k) = 10; + k = kh_get(32, h, 10); + is_missing = (k == kh_end(h)); + k = kh_get(32, h, 5); + kh_del(32, h, k); + for (k = kh_begin(h); k != kh_end(h); ++k) + if (kh_exist(h, k)) kh_value(h, k) = 1; + kh_destroy(32, h); + return 0; +} +*/ + +/* + 2013-05-02 (0.2.8): + + * Use quadratic probing. When the capacity is power of 2, stepping function + i*(i+1)/2 guarantees to traverse each bucket. It is better than double + hashing on cache performance and is more robust than linear probing. + + In theory, double hashing should be more robust than quadratic probing. + However, my implementation is probably not for large hash tables, because + the second hash function is closely tied to the first hash function, + which reduce the effectiveness of double hashing. + + Reference: http://research.cs.vt.edu/AVresearch/hashing/quadratic.php + + 2011-12-29 (0.2.7): + + * Minor code clean up; no actual effect. + + 2011-09-16 (0.2.6): + + * The capacity is a power of 2. This seems to dramatically improve the + speed for simple keys. Thank Zilong Tan for the suggestion. Reference: + + - http://code.google.com/p/ulib/ + - http://nothings.org/computer/judy/ + + * Allow to optionally use linear probing which usually has better + performance for random input. Double hashing is still the default as it + is more robust to certain non-random input. + + * Added Wang's integer hash function (not used by default). This hash + function is more robust to certain non-random input. + + 2011-02-14 (0.2.5): + + * Allow to declare global functions. + + 2009-09-26 (0.2.4): + + * Improve portability + + 2008-09-19 (0.2.3): + + * Corrected the example + * Improved interfaces + + 2008-09-11 (0.2.2): + + * Improved speed a little in kh_put() + + 2008-09-10 (0.2.1): + + * Added kh_clear() + * Fixed a compiling error + + 2008-09-02 (0.2.0): + + * Changed to token concatenation which increases flexibility. + + 2008-08-31 (0.1.2): + + * Fixed a bug in kh_get(), which has not been tested previously. + + 2008-08-31 (0.1.1): + + * Added destructor +*/ + + +#ifndef __AC_KHASH_H +#define __AC_KHASH_H + +/*! + @header + + Generic hash table library. 
+ */ + +#define AC_VERSION_KHASH_H "0.2.8" + +#include +#include +#include + +/* compiler specific configuration */ + +#if UINT_MAX == 0xffffffffu +typedef unsigned int khint32_t; +#elif ULONG_MAX == 0xffffffffu +typedef unsigned long khint32_t; +#endif + +#if ULONG_MAX == ULLONG_MAX +typedef unsigned long khint64_t; +#else +typedef uint64_t khint64_t; +#endif + +#ifndef kh_inline +#ifdef _MSC_VER +#define kh_inline __inline +#else +#define kh_inline inline +#endif +#endif /* kh_inline */ + +#ifndef klib_unused +#if (defined __clang__ && __clang_major__ >= 3) || (defined __GNUC__ && __GNUC__ >= 3) +#define klib_unused __attribute__ ((__unused__)) +#else +#define klib_unused +#endif +#endif /* klib_unused */ + +typedef khint32_t khint_t; +typedef khint_t khiter_t; + +#define __ac_isempty(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&2) +#define __ac_isdel(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&1) +#define __ac_iseither(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&3) +#define __ac_set_isdel_false(flag, i) (flag[i>>4]&=~(1ul<<((i&0xfU)<<1))) +#define __ac_set_isempty_false(flag, i) (flag[i>>4]&=~(2ul<<((i&0xfU)<<1))) +#define __ac_set_isboth_false(flag, i) (flag[i>>4]&=~(3ul<<((i&0xfU)<<1))) +#define __ac_set_isdel_true(flag, i) (flag[i>>4]|=1ul<<((i&0xfU)<<1)) + +#define __ac_fsize(m) ((m) < 16? 1 : (m)>>4) + +#ifndef kroundup32 +#define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) +#endif + +#ifndef kcalloc +#define kcalloc(N,Z) calloc(N,Z) +#endif +#ifndef kmalloc +#define kmalloc(Z) malloc(Z) +#endif +#ifndef krealloc +#define krealloc(P,Z) realloc(P,Z) +#endif +#ifndef kfree +#define kfree(P) free(P) +#endif + +static const double __ac_HASH_UPPER = 0.77; + +#define __KHASH_TYPE(name, khkey_t, khval_t) \ + typedef struct kh_##name##_s { \ + khint_t n_buckets, size, n_occupied, upper_bound; \ + khint32_t *flags; \ + khkey_t *keys; \ + khval_t *vals; \ + } kh_##name##_t; + +#define __KHASH_PROTOTYPES(name, khkey_t, khval_t) \ + extern kh_##name##_t *kh_init_##name(void); \ + extern void kh_destroy_##name(kh_##name##_t *h); \ + extern void kh_clear_##name(kh_##name##_t *h); \ + extern khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key); \ + extern int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets); \ + extern khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret); \ + extern void kh_del_##name(kh_##name##_t *h, khint_t x); + +#define __KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + SCOPE kh_##name##_t *kh_init_##name(void) { \ + return (kh_##name##_t*)kcalloc(1, sizeof(kh_##name##_t)); \ + } \ + SCOPE void kh_destroy_##name(kh_##name##_t *h) \ + { \ + if (h) { \ + kfree((void *)h->keys); kfree(h->flags); \ + kfree((void *)h->vals); \ + kfree(h); \ + } \ + } \ + SCOPE void kh_clear_##name(kh_##name##_t *h) \ + { \ + if (h && h->flags) { \ + memset(h->flags, 0xaa, __ac_fsize(h->n_buckets) * sizeof(khint32_t)); \ + h->size = h->n_occupied = 0; \ + } \ + } \ + SCOPE khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key) \ + { \ + if (h->n_buckets) { \ + khint_t k, i, last, mask, step = 0; \ + mask = h->n_buckets - 1; \ + k = __hash_func(key); i = k & mask; \ + last = i; \ + while (!__ac_isempty(h->flags, i) && (__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ + i = (i + (++step)) & mask; \ + if (i == last) return h->n_buckets; \ + } \ + return __ac_iseither(h->flags, i)? 
h->n_buckets : i; \ + } else return 0; \ + } \ + SCOPE int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets) \ + { /* This function uses 0.25*n_buckets bytes of working space instead of [sizeof(key_t+val_t)+.25]*n_buckets. */ \ + khint32_t *new_flags = 0; \ + khint_t j = 1; \ + { \ + kroundup32(new_n_buckets); \ + if (new_n_buckets < 4) new_n_buckets = 4; \ + if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; /* requested size is too small */ \ + else { /* hash table size to be changed (shrink or expand); rehash */ \ + new_flags = (khint32_t*)kmalloc(__ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ + if (!new_flags) return -1; \ + memset(new_flags, 0xaa, __ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ + if (h->n_buckets < new_n_buckets) { /* expand */ \ + khkey_t *new_keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ + if (!new_keys) { kfree(new_flags); return -1; } \ + h->keys = new_keys; \ + if (kh_is_map) { \ + khval_t *new_vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ + if (!new_vals) { kfree(new_flags); return -1; } \ + h->vals = new_vals; \ + } \ + } /* otherwise shrink */ \ + } \ + } \ + if (j) { /* rehashing is needed */ \ + for (j = 0; j != h->n_buckets; ++j) { \ + if (__ac_iseither(h->flags, j) == 0) { \ + khkey_t key = h->keys[j]; \ + khval_t val; \ + khint_t new_mask; \ + new_mask = new_n_buckets - 1; \ + if (kh_is_map) val = h->vals[j]; \ + __ac_set_isdel_true(h->flags, j); \ + while (1) { /* kick-out process; sort of like in Cuckoo hashing */ \ + khint_t k, i, step = 0; \ + k = __hash_func(key); \ + i = k & new_mask; \ + while (!__ac_isempty(new_flags, i)) i = (i + (++step)) & new_mask; \ + __ac_set_isempty_false(new_flags, i); \ + if (i < h->n_buckets && __ac_iseither(h->flags, i) == 0) { /* kick out the existing element */ \ + { khkey_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } \ + if (kh_is_map) { khval_t tmp = h->vals[i]; h->vals[i] = val; val = tmp; } \ + __ac_set_isdel_true(h->flags, i); /* mark it as deleted in the old hash table */ \ + } else { /* write the element and jump out of the loop */ \ + h->keys[i] = key; \ + if (kh_is_map) h->vals[i] = val; \ + break; \ + } \ + } \ + } \ + } \ + if (h->n_buckets > new_n_buckets) { /* shrink the hash table */ \ + h->keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ + if (kh_is_map) h->vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ + } \ + kfree(h->flags); /* free the working space */ \ + h->flags = new_flags; \ + h->n_buckets = new_n_buckets; \ + h->n_occupied = h->size; \ + h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); \ + } \ + return 0; \ + } \ + SCOPE khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret) \ + { \ + khint_t x; \ + if (h->n_occupied >= h->upper_bound) { /* update the hash table */ \ + if (h->n_buckets > (h->size<<1)) { \ + if (kh_resize_##name(h, h->n_buckets - 1) < 0) { /* clear "deleted" elements */ \ + *ret = -1; return h->n_buckets; \ + } \ + } else if (kh_resize_##name(h, h->n_buckets + 1) < 0) { /* expand the hash table */ \ + *ret = -1; return h->n_buckets; \ + } \ + } /* TODO: to implement automatically shrinking; resize() already support shrinking */ \ + { \ + khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; \ + x = site = h->n_buckets; k = __hash_func(key); i = k & mask; \ + if (__ac_isempty(h->flags, i)) x = i; /* for speed up */ \ + else { \ + last = i; \ + while (!__ac_isempty(h->flags, i) && 
(__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ + if (__ac_isdel(h->flags, i)) site = i; \ + i = (i + (++step)) & mask; \ + if (i == last) { x = site; break; } \ + } \ + if (x == h->n_buckets) { \ + if (__ac_isempty(h->flags, i) && site != h->n_buckets) x = site; \ + else x = i; \ + } \ + } \ + } \ + if (__ac_isempty(h->flags, x)) { /* not present at all */ \ + h->keys[x] = key; \ + __ac_set_isboth_false(h->flags, x); \ + ++h->size; ++h->n_occupied; \ + *ret = 1; \ + } else if (__ac_isdel(h->flags, x)) { /* deleted */ \ + h->keys[x] = key; \ + __ac_set_isboth_false(h->flags, x); \ + ++h->size; \ + *ret = 2; \ + } else *ret = 0; /* Don't touch h->keys[x] if present and not deleted */ \ + return x; \ + } \ + SCOPE void kh_del_##name(kh_##name##_t *h, khint_t x) \ + { \ + if (x != h->n_buckets && !__ac_iseither(h->flags, x)) { \ + __ac_set_isdel_true(h->flags, x); \ + --h->size; \ + } \ + } + +#define KHASH_DECLARE(name, khkey_t, khval_t) \ + __KHASH_TYPE(name, khkey_t, khval_t) \ + __KHASH_PROTOTYPES(name, khkey_t, khval_t) + +#define KHASH_INIT2(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + __KHASH_TYPE(name, khkey_t, khval_t) \ + __KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) + +#define KHASH_INIT(name, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + KHASH_INIT2(name, static kh_inline klib_unused, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) + +/* --- BEGIN OF HASH FUNCTIONS --- */ + +/*! @function + @abstract Integer hash function + @param key The integer [khint32_t] + @return The hash value [khint_t] + */ +#define kh_int_hash_func(key) (khint32_t)(key) +/*! @function + @abstract Integer comparison function + */ +#define kh_int_hash_equal(a, b) ((a) == (b)) +/*! @function + @abstract 64-bit integer hash function + @param key The integer [khint64_t] + @return The hash value [khint_t] + */ +#define kh_int64_hash_func(key) (khint32_t)((key)>>33^(key)^(key)<<11) +/*! @function + @abstract 64-bit integer comparison function + */ +#define kh_int64_hash_equal(a, b) ((a) == (b)) +/*! @function + @abstract const char* hash function + @param s Pointer to a null terminated string + @return The hash value + */ +static kh_inline khint_t __ac_X31_hash_string(const char *s) +{ + khint_t h = (khint_t)*s; + if (h) for (++s ; *s; ++s) h = (h << 5) - h + (khint_t)*s; + return h; +} +/*! @function + @abstract Another interface to const char* hash function + @param key Pointer to a null terminated string [const char*] + @return The hash value [khint_t] + */ +#define kh_str_hash_func(key) __ac_X31_hash_string(key) +/*! @function + @abstract Const char* comparison function + */ +#define kh_str_hash_equal(a, b) (strcmp(a, b) == 0) + +static kh_inline khint_t __ac_Wang_hash(khint_t key) +{ + key += ~(key << 15); + key ^= (key >> 10); + key += (key << 3); + key ^= (key >> 6); + key += ~(key << 11); + key ^= (key >> 16); + return key; +} +#define kh_int_hash_func2(key) __ac_Wang_hash((khint_t)key) + +/* --- END OF HASH FUNCTIONS --- */ + +/* Other convenient macros... */ + +/*! + @abstract Type of the hash table. + @param name Name of the hash table [symbol] + */ +#define khash_t(name) kh_##name##_t + +/*! @function + @abstract Initiate a hash table. + @param name Name of the hash table [symbol] + @return Pointer to the hash table [khash_t(name)*] + */ +#define kh_init(name) kh_init_##name() + +/*! @function + @abstract Destroy a hash table. 
+ @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + */ +#define kh_destroy(name, h) kh_destroy_##name(h) + +/*! @function + @abstract Reset a hash table without deallocating memory. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + */ +#define kh_clear(name, h) kh_clear_##name(h) + +/*! @function + @abstract Resize a hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param s New size [khint_t] + */ +#define kh_resize(name, h, s) kh_resize_##name(h, s) + +/*! @function + @abstract Insert a key to the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Key [type of keys] + @param r Extra return code: -1 if the operation failed; + 0 if the key is present in the hash table; + 1 if the bucket is empty (never used); 2 if the element in + the bucket has been deleted [int*] + @return Iterator to the inserted element [khint_t] + */ +#define kh_put(name, h, k, r) kh_put_##name(h, k, r) + +/*! @function + @abstract Retrieve a key from the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Key [type of keys] + @return Iterator to the found element, or kh_end(h) if the element is absent [khint_t] + */ +#define kh_get(name, h, k) kh_get_##name(h, k) + +/*! @function + @abstract Remove a key from the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Iterator to the element to be deleted [khint_t] + */ +#define kh_del(name, h, k) kh_del_##name(h, k) + +/*! @function + @abstract Test whether a bucket contains data. + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return 1 if containing data; 0 otherwise [int] + */ +#define kh_exist(h, x) (!__ac_iseither((h)->flags, (x))) + +/*! @function + @abstract Get key given an iterator + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return Key [type of keys] + */ +#define kh_key(h, x) ((h)->keys[x]) + +/*! @function + @abstract Get value given an iterator + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return Value [type of values] + @discussion For hash sets, calling this results in segfault. + */ +#define kh_val(h, x) ((h)->vals[x]) + +/*! @function + @abstract Alias of kh_val() + */ +#define kh_value(h, x) ((h)->vals[x]) + +/*! @function + @abstract Get the start iterator + @param h Pointer to the hash table [khash_t(name)*] + @return The start iterator [khint_t] + */ +#define kh_begin(h) (khint_t)(0) + +/*! @function + @abstract Get the end iterator + @param h Pointer to the hash table [khash_t(name)*] + @return The end iterator [khint_t] + */ +#define kh_end(h) ((h)->n_buckets) + +/*! @function + @abstract Get the number of elements in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @return Number of elements in the hash table [khint_t] + */ +#define kh_size(h) ((h)->size) + +/*! @function + @abstract Get the number of buckets in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @return Number of buckets in the hash table [khint_t] + */ +#define kh_n_buckets(h) ((h)->n_buckets) + +/*! 
@function + @abstract Iterate over the entries in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @param kvar Variable to which key will be assigned + @param vvar Variable to which value will be assigned + @param code Block of code to execute + */ +#define kh_foreach(h, kvar, vvar, code) { khint_t __i; \ + for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ + if (!kh_exist(h,__i)) continue; \ + (kvar) = kh_key(h,__i); \ + (vvar) = kh_val(h,__i); \ + code; \ + } } + +/*! @function + @abstract Iterate over the values in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @param vvar Variable to which value will be assigned + @param code Block of code to execute + */ +#define kh_foreach_value(h, vvar, code) { khint_t __i; \ + for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ + if (!kh_exist(h,__i)) continue; \ + (vvar) = kh_val(h,__i); \ + code; \ + } } + +/* More conenient interfaces */ + +/*! @function + @abstract Instantiate a hash set containing integer keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_INT(name) \ + KHASH_INIT(name, khint32_t, char, 0, kh_int_hash_func, kh_int_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing integer keys + @param name Name of the hash table [symbol] + @param khval_t Type of values [type] + */ +#define KHASH_MAP_INIT_INT(name, khval_t) \ + KHASH_INIT(name, khint32_t, khval_t, 1, kh_int_hash_func, kh_int_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing 64-bit integer keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_INT64(name) \ + KHASH_INIT(name, khint64_t, char, 0, kh_int64_hash_func, kh_int64_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing 64-bit integer keys + @param name Name of the hash table [symbol] + @param khval_t Type of values [type] + */ +#define KHASH_MAP_INIT_INT64(name, khval_t) \ + KHASH_INIT(name, khint64_t, khval_t, 1, kh_int64_hash_func, kh_int64_hash_equal) + +typedef const char *kh_cstr_t; +/*! @function + @abstract Instantiate a hash map containing const char* keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_STR(name) \ + KHASH_INIT(name, kh_cstr_t, char, 0, kh_str_hash_func, kh_str_hash_equal) + +/*! 
@function
+  @abstract     Instantiate a hash map containing const char* keys
+  @param  name  Name of the hash table [symbol]
+  @param  khval_t  Type of values [type]
+ */
+#define KHASH_MAP_INIT_STR(name, khval_t) \
+	KHASH_INIT(name, kh_cstr_t, khval_t, 1, kh_str_hash_func, kh_str_hash_equal)
+
+#endif /* __AC_KHASH_H */
diff --git a/twml/libtwml/src/lib/internal/linear_search.h b/twml/libtwml/src/lib/internal/linear_search.h
new file mode 100644
index 000000000..a3d294853
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/linear_search.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#ifdef __cplusplus
+#include <cstdint>
+namespace twml {
+
+  template <typename Tx>
+  static int64_t linear_search(const Tx *xsData, const Tx val, const int64_t mainSize) {
+    int64_t left = 0;
+    int64_t right = mainSize - 1;
+    while (left <= right && val > xsData[left])
+      left++;
+    return left;
+  }
+
+}  // namespace twml
+#endif
diff --git a/twml/libtwml/src/lib/internal/murmur_hash3.h b/twml/libtwml/src/lib/internal/murmur_hash3.h
new file mode 100644
index 000000000..3bdfbe486
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/murmur_hash3.h
@@ -0,0 +1,37 @@
+//-----------------------------------------------------------------------------
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain. The author hereby disclaims copyright to this source code.
+
+#ifndef _MURMURHASH3_H_
+#define _MURMURHASH3_H_
+
+//-----------------------------------------------------------------------------
+// Platform-specific functions and macros
+
+// Microsoft Visual Studio
+
+#if defined(_MSC_VER) && (_MSC_VER < 1600)
+
+typedef unsigned char uint8_t;
+typedef unsigned int uint32_t;
+typedef unsigned __int64 uint64_t;
+
+// Other compilers
+
+#else // defined(_MSC_VER)
+
+#include <stdint.h>
+
+#endif // !defined(_MSC_VER)
+
+//-----------------------------------------------------------------------------
+
+void MurmurHash3_x86_32  ( const void * key, int len, uint32_t seed, void * out );
+
+void MurmurHash3_x86_128 ( const void * key, int len, uint32_t seed, void * out );
+
+void MurmurHash3_x64_128 ( const void * key, int len, uint32_t seed, void * out );
+
+//-----------------------------------------------------------------------------
+
+#endif // _MURMURHASH3_H_
diff --git a/twml/libtwml/src/lib/internal/thrift.h b/twml/libtwml/src/lib/internal/thrift.h
new file mode 100644
index 000000000..4e4786219
--- /dev/null
+++ b/twml/libtwml/src/lib/internal/thrift.h
@@ -0,0 +1,69 @@
+// For details of how to encode and decode thrift, check
+// https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md
+
+// Definitions of the thrift binary format
+typedef enum {
+  TTYPE_STOP   =  0,
+  TTYPE_VOID   =  1,
+  TTYPE_BOOL   =  2,
+  TTYPE_BYTE   =  3,
+  TTYPE_DOUBLE =  4,
+  TTYPE_I16    =  6,
+  TTYPE_I32    =  8,
+  TTYPE_I64    = 10,
+  TTYPE_STRING = 11,
+  TTYPE_STRUCT = 12,
+  TTYPE_MAP    = 13,
+  TTYPE_SET    = 14,
+  TTYPE_LIST   = 15,
+  TTYPE_ENUM   = 16,
+} TTYPES;
+
+// Fields of a batch prediction response
+typedef enum {
+  BPR_DUMMY ,
+  BPR_PREDICTIONS,
+} BPR_FIELDS;
+
+// Fields of a datarecord
+typedef enum {
+  DR_CROSS ,  // fake field for crosses
+  DR_BINARY ,
+  DR_CONTINUOUS ,
+  DR_DISCRETE ,
+  DR_STRING ,
+  DR_SPARSE_BINARY ,
+  DR_SPARSE_CONTINUOUS ,
+  DR_BLOB ,
+  DR_GENERAL_TENSOR ,
+  DR_SPARSE_TENSOR ,
+} DR_FIELDS;
+
+// Fields for General tensor
+typedef enum {
+  GT_DUMMY ,  // dummy field
+  GT_RAW ,
+  GT_STRING ,
+  GT_INT32 ,
+  GT_INT64 ,
+  GT_FLOAT ,
+  GT_DOUBLE ,
+  GT_BOOL ,
+} GT_FIELDS;
+
+typedef enum {
+  SP_DUMMY ,  // dummy field
+  SP_COO ,
+} SP_FIELDS;
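For orientation, a field in the thrift binary protocol starts with a one-byte type (the TTYPES values above) followed by a big-endian 16-bit field id; a TTYPE_STOP byte terminates a struct and carries no field id. A minimal reader sketch under those assumptions (editorial illustration drawn from the spec linked at the top of this header, not part of this diff):

#include <cstdint>

// Reads a field header at `p` and returns the advanced pointer.
// Layout per the thrift binary protocol spec referenced above.
const uint8_t *read_field_header(const uint8_t *p, uint8_t *type, int16_t *field_id) {
  *type = *p++;
  if (*type == 0 /* TTYPE_STOP */) return p;   // no field id follows a stop byte
  *field_id = (int16_t)((p[0] << 8) | p[1]);   // big-endian i16
  return p + 2;
}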
+ +// Enum values from tensor.thrift +typedef enum { + DATA_TYPE_FLOAT , + DATA_TYPE_DOUBLE , + DATA_TYPE_INT32 , + DATA_TYPE_INT64 , + DATA_TYPE_UINT8 , + DATA_TYPE_STRING , + DATA_TYPE_BYTE , + DATA_TYPE_BOOL , +} DATA_TYPES; diff --git a/twml/libtwml/src/lib/internal/utf_converter.h b/twml/libtwml/src/lib/internal/utf_converter.h new file mode 100644 index 000000000..b0b38fb11 --- /dev/null +++ b/twml/libtwml/src/lib/internal/utf_converter.h @@ -0,0 +1,10 @@ +#ifndef _UTF_CONVERTER_H_ +#define _UTF_CONVERTER_H_ + +#include +#include +#include + +ssize_t utf8_to_utf16(const uint8_t *in, uint64_t in_len, uint16_t *out, uint64_t max_out); + +#endif diff --git a/twml/libtwml/src/lib/io/IOError.cpp b/twml/libtwml/src/lib/io/IOError.cpp new file mode 100644 index 000000000..e0a661c13 --- /dev/null +++ b/twml/libtwml/src/lib/io/IOError.cpp @@ -0,0 +1,61 @@ +#include + + +namespace twml { +namespace io { + +namespace { + std::string messageFromStatus(IOError::Status status) { + switch (status) { + case IOError::OUT_OF_RANGE: + return "failed to read enough input"; + case IOError::WRONG_MAGIC: + return "wrong magic in stream"; + case IOError::WRONG_HEADER: + return "wrong header in stream"; + case IOError::ERROR_HEADER_CHECKSUM: + return "header checksum doesn't match"; + case IOError::INVALID_METHOD: + return "using invalid method"; + case IOError::USING_RESERVED: + return "using reserved flag"; + case IOError::ERROR_HEADER_EXTRA_FIELD_CHECKSUM: + return "extra header field checksum doesn't match"; + case IOError::CANT_FIT_OUTPUT: + return "can't fit output in the given space"; + case IOError::SPLIT_FILE: + return "split files aren't supported"; + case IOError::BLOCK_SIZE_TOO_LARGE: + return "block size is too large"; + case IOError::SOURCE_LARGER_THAN_DESTINATION: + return "source is larger than destination"; + case IOError::DESTINATION_LARGER_THAN_CAPACITY: + return "destination buffer is too small to fit uncompressed result"; + case IOError::HEADER_FLAG_MISMATCH: + return "failed to match flags for compressed and decompressed data"; + case IOError::NOT_ENOUGH_INPUT: + return "not enough input to proceed with decompression"; + case IOError::ERROR_SOURCE_BLOCK_CHECKSUM: + return "source block checksum doesn't match"; + case IOError::COMPRESSED_DATA_VIOLATION: + return "error occurred while decompressing the data"; + case IOError::ERROR_DESTINATION_BLOCK_CHECKSUM: + return "destination block checksum doesn't match"; + case IOError::EMPTY_RECORD: + return "can't write an empty record"; + case IOError::MALFORMED_MEMORY_RECORD: + return "can't write malformed record"; + case IOError::UNSUPPORTED_OUTPUT_TYPE: + return "output data type is not supported"; + case IOError::OTHER_ERROR: + default: + return "unknown error occurred"; + } + } +} // namespace + +IOError::IOError(Status status): twml::Error(TWML_ERR_IO, "Found error while processing stream: " + + messageFromStatus(status)), m_status(status) {} + +} // namespace io +} // namespace twml diff --git a/twml/libtwml/src/lib/murmur_hash3.cpp b/twml/libtwml/src/lib/murmur_hash3.cpp new file mode 100644 index 000000000..89c9c1fc1 --- /dev/null +++ b/twml/libtwml/src/lib/murmur_hash3.cpp @@ -0,0 +1,335 @@ +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. 
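A small usage sketch of the 32-bit variant declared in the header above (an editorial illustration, not part of this diff; the include path is assumed from the tree layout):

#include <cstdio>
#include <cstring>
#include "internal/murmur_hash3.h"

int main() {
  const char *key = "feature_name";
  uint32_t out = 0;
  // Hash the bytes of `key` with seed 42 into a single 32-bit value.
  MurmurHash3_x86_32(key, (int)std::strlen(key), /*seed=*/42, &out);
  std::printf("%u\n", out);
  return 0;
}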
+ +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. + +#include "internal/murmur_hash3.h" + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x,y) _rotl(x,y) +#define ROTL64(x,y) _rotl64(x,y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#define FORCE_INLINE inline __attribute__((always_inline)) + +FORCE_INLINE uint32_t rotl32 ( uint32_t x, int8_t r ) +{ + return (x << r) | (x >> (32 - r)); +} + +FORCE_INLINE uint64_t rotl64 ( uint64_t x, int8_t r ) +{ + return (x << r) | (x >> (64 - r)); +} + +#define ROTL32(x,y) rotl32(x,y) +#define ROTL64(x,y) rotl64(x,y) + +#define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- +// Block read - if your platform needs to do endian-swapping or can only +// handle aligned reads, do the conversion here + +FORCE_INLINE uint32_t getblock32 ( const uint32_t * p, int i ) +{ + return p[i]; +} + +FORCE_INLINE uint64_t getblock64 ( const uint64_t * p, int i ) +{ + return p[i]; +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE uint32_t fmix32 ( uint32_t h ) +{ + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +//---------- + +FORCE_INLINE uint64_t fmix64 ( uint64_t k ) +{ + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_32 ( const void * key, int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*4); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i); + + k1 *= c1; + k1 = ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*4); + + uint32_t k1 = 0; + + switch(len & 3) + { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix32(h1); + + *(uint32_t*)out = h1; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_128 ( const void * key, const int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + + const uint32_t c1 = 0x239b961b; + const uint32_t c2 = 0xab0e9789; + const uint32_t c3 = 0x38b34ae5; + const uint32_t c4 = 0xa1e38b93; + + //---------- + // body + + const uint32_t * blocks = (const 
uint32_t *)(data + nblocks*16); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i*4+0); + uint32_t k2 = getblock32(blocks,i*4+1); + uint32_t k3 = getblock32(blocks,i*4+2); + uint32_t k4 = getblock32(blocks,i*4+3); + + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + + h1 = ROTL32(h1,19); h1 += h2; h1 = h1*5+0x561ccd1b; + + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + h2 = ROTL32(h2,17); h2 += h3; h2 = h2*5+0x0bcaa747; + + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + h3 = ROTL32(h3,15); h3 += h4; h3 = h3*5+0x96cd1c35; + + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + h4 = ROTL32(h4,13); h4 += h1; h4 = h4*5+0x32ac3b17; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + + switch(len & 15) + { + case 15: k4 ^= tail[14] << 16; + case 14: k4 ^= tail[13] << 8; + case 13: k4 ^= tail[12] << 0; + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + case 12: k3 ^= tail[11] << 24; + case 11: k3 ^= tail[10] << 16; + case 10: k3 ^= tail[ 9] << 8; + case 9: k3 ^= tail[ 8] << 0; + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + case 8: k2 ^= tail[ 7] << 24; + case 7: k2 ^= tail[ 6] << 16; + case 6: k2 ^= tail[ 5] << 8; + case 5: k2 ^= tail[ 4] << 0; + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + case 4: k1 ^= tail[ 3] << 24; + case 3: k1 ^= tail[ 2] << 16; + case 2: k1 ^= tail[ 1] << 8; + case 1: k1 ^= tail[ 0] << 0; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; h3 ^= len; h4 ^= len; + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + h1 = fmix32(h1); + h2 = fmix32(h2); + h3 = fmix32(h3); + h4 = fmix32(h4); + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x64_128 ( const void * key, const int len, + const uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint64_t h1 = seed; + uint64_t h2 = seed; + + const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + const uint64_t * blocks = (const uint64_t *)(data); + + for(int i = 0; i < nblocks; i++) + { + uint64_t k1 = getblock64(blocks,i*2+0); + uint64_t k2 = getblock64(blocks,i*2+1); + + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint64_t k1 = 0; + uint64_t k2 = 0; + + switch(len & 15) + { + case 15: k2 ^= ((uint64_t)tail[14]) << 48; + case 14: k2 ^= ((uint64_t)tail[13]) << 40; + case 13: k2 ^= ((uint64_t)tail[12]) << 32; + case 12: k2 ^= ((uint64_t)tail[11]) << 24; + case 11: k2 ^= ((uint64_t)tail[10]) << 16; + case 10: k2 ^= ((uint64_t)tail[ 9]) << 8; + case 9: k2 ^= ((uint64_t)tail[ 8]) << 0; + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + case 8: k1 ^= ((uint64_t)tail[ 7]) << 56; + case 7: k1 ^= ((uint64_t)tail[ 6]) << 48; + case 6: k1 ^= ((uint64_t)tail[ 5]) << 40; + case 5: k1 ^= ((uint64_t)tail[ 4]) << 32; + case 4: k1 ^= 
((uint64_t)tail[ 3]) << 24;
+  case  3: k1 ^= ((uint64_t)tail[ 2]) << 16;
+  case  2: k1 ^= ((uint64_t)tail[ 1]) << 8;
+  case  1: k1 ^= ((uint64_t)tail[ 0]) << 0;
+           k1 *= c1; k1  = ROTL64(k1,31); k1 *= c2; h1 ^= k1;
+  };
+
+  //----------
+  // finalization
+
+  h1 ^= len; h2 ^= len;
+
+  h1 += h2;
+  h2 += h1;
+
+  h1 = fmix64(h1);
+  h2 = fmix64(h2);
+
+  h1 += h2;
+  h2 += h1;
+
+  ((uint64_t*)out)[0] = h1;
+  ((uint64_t*)out)[1] = h2;
+}
+
+//-----------------------------------------------------------------------------
+
diff --git a/twml/libtwml/src/lib/optim.cpp b/twml/libtwml/src/lib/optim.cpp
new file mode 100644
index 000000000..7db36c26d
--- /dev/null
+++ b/twml/libtwml/src/lib/optim.cpp
@@ -0,0 +1,274 @@
+#include "internal/interpolate.h"
+#include "internal/error.h"
+#include <twml/optim.h>
+
+namespace twml {
+  template <typename T>
+  void mdlInfer(Tensor &output_keys, Tensor &output_vals,
+                const Tensor &input_keys, const Tensor &input_vals,
+                const Tensor &bin_ids,
+                const Tensor &bin_vals,
+                const Tensor &feature_offsets,
+                bool return_bin_indices) {
+    auto okeysData = output_keys.getData<int64_t>();
+    auto ovalsData = output_vals.getData<T>();
+    uint64_t okeysStride = output_keys.getStride(0);
+    uint64_t ovaluesStride = output_vals.getStride(0);
+
+    auto ikeysData = input_keys.getData<int64_t>();
+    auto ivalsData = input_vals.getData<T>();
+    uint64_t ikeysStride = input_keys.getStride(0);
+    uint64_t ivaluesStride = input_vals.getStride(0);
+
+    auto xsData = bin_vals.getData<T>();
+    auto ysData = bin_ids.getData<int64_t>();
+    uint64_t xsStride = bin_vals.getStride(0);
+    uint64_t ysStride = bin_ids.getStride(0);
+
+    auto offsetData = feature_offsets.getData<int64_t>();
+
+    uint64_t size = input_keys.getDim(0);
+    uint64_t total_bins = bin_ids.getNumElements();
+    uint64_t fsize = feature_offsets.getNumElements();
+
+    for (uint64_t i = 0; i < size; i++) {
+      int64_t ikey = ikeysData[i * ikeysStride] - TWML_INDEX_BASE;
+      T val = ivalsData[i * ivaluesStride];
+      if (ikey == -1) {
+        ovalsData[i * ovaluesStride] = val;
+        continue;
+      }
+
+      // Perform interpolation
+      uint64_t offset = offsetData[ikey];
+      uint64_t next_offset = (ikey == (int64_t)(fsize - 1)) ? total_bins : offsetData[ikey + 1];
+      uint64_t mainSize = next_offset - offset;
+
+      const T *lxsData = xsData + offset;
+      const int64_t *lysData = ysData + offset;
+      int64_t okey = interpolation(lxsData, xsStride,
+                                   lysData, ysStride,
+                                   val, mainSize, NEAREST, 0,
+                                   return_bin_indices);
+      okeysData[i * okeysStride] = okey + TWML_INDEX_BASE;
+      ovalsData[i * ovaluesStride] = 1;
+    }
+  }
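The feature_offsets layout consumed above is the usual CSR-style slicing: feature k owns bins [offsets[k], offsets[k+1]) within the flat bin_ids/bin_vals arrays, with total_bins closing the last slice. A toy illustration (an editorial sketch with made-up numbers, not part of this diff):

#include <cstdint>
#include <cstdio>

int main() {
  // Three features with 2, 3, and 1 calibrated bins respectively.
  int64_t offsets[] = {0, 2, 5};
  int64_t total_bins = 6;
  for (int64_t k = 0; k < 3; k++) {
    int64_t lo = offsets[k];
    int64_t hi = (k == 2) ? total_bins : offsets[k + 1];
    std::printf("feature %lld owns bins [%lld, %lld)\n",
                (long long)k, (long long)lo, (long long)hi);
  }
  return 0;
}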
+
+  void mdlInfer(Tensor &output_keys, Tensor &output_vals,
+                const Tensor &input_keys, const Tensor &input_vals,
+                const Tensor &bin_ids,
+                const Tensor &bin_vals,
+                const Tensor &feature_offsets,
+                bool return_bin_indices) {
+    if (input_keys.getType() != TWML_TYPE_INT64) {
+      throw twml::Error(TWML_ERR_TYPE, "input_keys must be a Long Tensor");
+    }
+
+    if (output_keys.getType() != TWML_TYPE_INT64) {
+      throw twml::Error(TWML_ERR_TYPE, "output_keys must be a Long Tensor");
+    }
+
+    if (bin_ids.getType() != TWML_TYPE_INT64) {
+      throw twml::Error(TWML_ERR_TYPE, "bin_ids must be a Long Tensor");
+    }
+
+    if (feature_offsets.getType() != TWML_TYPE_INT64) {
+      throw twml::Error(TWML_ERR_TYPE, "feature_offsets must be a Long Tensor");
+    }
+
+    if (input_vals.getType() != bin_vals.getType()) {
+      throw twml::Error(TWML_ERR_TYPE,
+        "Data type of input_vals does not match type of bin_vals");
+    }
+
+    if (bin_vals.getNumDims() != 1) {
+      throw twml::Error(TWML_ERR_SIZE,
+        "bin_vals must be 1 Dimensional");
+    }
+
+    if (bin_ids.getNumDims() != 1) {
+      throw twml::Error(TWML_ERR_SIZE,
+        "bin_ids must be 1 Dimensional");
+    }
+
+    if (bin_vals.getNumElements() != bin_ids.getNumElements()) {
+      throw twml::Error(TWML_ERR_SIZE,
+        "Dimensions of bin_vals and bin_ids do not match");
+    }
+
+    if (feature_offsets.getStride(0) != 1) {
+      throw twml::Error(TWML_ERR_SIZE,
+        "feature_offsets must be contiguous");
+    }
+
+    switch (input_vals.getType()) {
+      case TWML_TYPE_FLOAT:
+        twml::mdlInfer<float>(output_keys, output_vals,
+                              input_keys, input_vals,
+                              bin_ids, bin_vals, feature_offsets,
+                              return_bin_indices);
+        break;
+      case TWML_TYPE_DOUBLE:
+        twml::mdlInfer<double>(output_keys, output_vals,
+                               input_keys, input_vals,
+                               bin_ids, bin_vals, feature_offsets,
+                               return_bin_indices);
+        break;
+      default:
+        throw twml::Error(TWML_ERR_TYPE,
+          "Unsupported datatype for mdlInfer");
+    }
+  }
+
+  const int DEFAULT_INTERPOLATION_LOWEST = 0;
+  /**
+   * @param output tensor to hold linear or nearest interpolation output.
+   *        This function does not allocate space.
+   *        The output tensor must have space allocated.
+   * @param input input tensor; size must match output.
+   *        input is assumed to have size [batch_size, number_of_labels].
+   * @param xs the bins.
+   * @param ys the values for the bins.
+   * @param mode linear or nearest InterpolationMode.
+   *        linear is used for isotonic calibration.
+   *        nearest is used for MDL calibration and MDL inference.
+   *
+   * @return Returns nothing. Output is stored into the output tensor.
+   *
+   * This is used by IsotonicCalibration inference.
+   */
+  template <typename T>
+  void interpolation(
+      Tensor output,
+      const Tensor input,
+      const Tensor xs,
+      const Tensor ys,
+      const InterpolationMode mode) {
+    // Sanity check: input and output should have two dims.
+    if (input.getNumDims() != 2 || output.getNumDims() != 2) {
+      throw twml::Error(TWML_ERR_TYPE,
+        "input and output should have 2 dimensions.");
+    }
+
+    // Sanity check: input and output size should match.
+ for (int i = 0; i < input.getNumDims(); i++) { + if (input.getDim(i) != output.getDim(i)) { + throw twml::Error(TWML_ERR_TYPE, + "input and output mismatch in size."); + } + } + + // Sanity check: number of labels in input should match + // number of labels in xs / ys. + if (input.getDim(1) != xs.getDim(0) + || input.getDim(1) != ys.getDim(0)) { + throw twml::Error(TWML_ERR_TYPE, + "input, xs, ys should have the same number of labels."); + } + + const uint64_t inputStride0 = input.getStride(0); + const uint64_t inputStride1 = input.getStride(1); + const uint64_t outputStride0 = output.getStride(0); + const uint64_t outputStride1 = output.getStride(1); + const uint64_t xsStride0 = xs.getStride(0); + const uint64_t xsStride1 = xs.getStride(1); + const uint64_t ysStride0 = ys.getStride(0); + const uint64_t ysStride1 = ys.getStride(1); + const uint64_t mainSize = xs.getDim(1); + + // for each value in the input matrix, compute output value by + // calling interpolation. + auto inputData = input.getData(); + auto outputData = output.getData(); + auto xsData = xs.getData(); + auto ysData = ys.getData(); + + for (uint64_t i = 0; i < input.getDim(0); i++) { + for (uint64_t j = 0; j < input.getDim(1); j++) { + const T val = inputData[i * inputStride0 + j * inputStride1]; + const T *lxsData = xsData + j * xsStride0; + const T *lysData = ysData + j * ysStride0; + const T res = interpolation( + lxsData, xsStride1, + lysData, ysStride1, + val, + mainSize, + mode, + DEFAULT_INTERPOLATION_LOWEST); + outputData[i * outputStride0 + j * outputStride1] = res; + } + } + } + + void linearInterpolation( + Tensor output, + const Tensor input, + const Tensor xs, + const Tensor ys) { + switch (input.getType()) { + case TWML_TYPE_FLOAT: + twml::interpolation(output, input, xs, ys, LINEAR); + break; + case TWML_TYPE_DOUBLE: + twml::interpolation(output, input, xs, ys, LINEAR); + break; + default: + throw twml::Error(TWML_ERR_TYPE, + "Unsupported datatype for linearInterpolation."); + } + } + + void nearestInterpolation( + Tensor output, + const Tensor input, + const Tensor xs, + const Tensor ys) { + switch (input.getType()) { + case TWML_TYPE_FLOAT: + twml::interpolation(output, input, xs, ys, NEAREST); + break; + case TWML_TYPE_DOUBLE: + twml::interpolation(output, input, xs, ys, NEAREST); + break; + default: + throw twml::Error(TWML_ERR_TYPE, + "Unsupported datatype for nearestInterpolation."); + } + } +} // namespace twml + +twml_err twml_optim_mdl_infer(twml_tensor output_keys, + twml_tensor output_vals, + const twml_tensor input_keys, + const twml_tensor input_vals, + const twml_tensor bin_ids, + const twml_tensor bin_vals, + const twml_tensor feature_offsets, + bool return_bin_indices) { + HANDLE_EXCEPTIONS( + using namespace twml; + mdlInfer(*getTensor(output_keys), + *getTensor(output_vals), + *getConstTensor(input_keys), + *getConstTensor(input_vals), + *getConstTensor(bin_ids), + *getConstTensor(bin_vals), + *getConstTensor(feature_offsets), + return_bin_indices);); + return TWML_ERR_NONE; +} + +twml_err twml_optim_nearest_interpolation( + twml_tensor output, + const twml_tensor input, + const twml_tensor xs, + const twml_tensor ys) { + HANDLE_EXCEPTIONS( + using namespace twml; + nearestInterpolation(*getTensor(output), + *getConstTensor(input), + *getConstTensor(xs), + *getConstTensor(ys));); + return TWML_ERR_NONE; +} diff --git a/twml/libtwml/src/lib/utf_converter.cpp b/twml/libtwml/src/lib/utf_converter.cpp new file mode 100644 index 000000000..5c943f3e3 --- /dev/null +++ 
b/twml/libtwml/src/lib/utf_converter.cpp @@ -0,0 +1,53 @@ +#include "internal/utf_converter.h" + +ssize_t utf8_to_utf16(const uint8_t *in, uint64_t in_len, uint16_t *out, uint64_t max_out) { + uint64_t num_out = 0; + uint64_t num_in = 0; + while (num_in < in_len) { + uint32_t uni; + uint64_t todo; + uint8_t ch = in[num_in]; + num_in++; + if (ch <= 0x7F) { + uni = ch; + todo = 0; + } else if (ch <= 0xBF) { + return -1; + } else if (ch <= 0xDF) { + uni = ch & 0x1F; + todo = 1; + } else if (ch <= 0xEF) { + uni = ch & 0x0F; + todo = 2; + } else if (ch <= 0xF7) { + uni = ch & 0x07; + todo = 3; + } else { + return -1; + } + for (uint64_t j = 0; j < todo; ++j) { + if (num_in == in_len) return -1; + uint8_t ch = in[num_in]; + num_in++; + if (ch < 0x80 || ch > 0xBF) return -1; + uni <<= 6; + uni += ch & 0x3F; + } + if (uni >= 0xD800 && uni <= 0xDFFF) return -1; + if (uni > 0x10FFFF) return -1; + if (uni <= 0xFFFF) { + if (num_out == max_out) return -1; + out[num_out] = uni; + num_out++; + } else { + uni -= 0x10000; + if (num_out + 1 >= max_out) return -1; + out[num_out] = (uni >> 10) + 0xD800; + out[num_out + 1] = (uni & 0x3FF) + 0xDC00; + num_out += 2; + } + } + if (num_out == max_out) return -1; + out[num_out] = 0; + return num_out; +} diff --git a/twml/libtwml/src/ops/CMakeLists.txt b/twml/libtwml/src/ops/CMakeLists.txt new file mode 100644 index 000000000..e2feaff23 --- /dev/null +++ b/twml/libtwml/src/ops/CMakeLists.txt @@ -0,0 +1,79 @@ +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}) +cmake_minimum_required(VERSION 2.8 FATAL_ERROR) +cmake_policy(VERSION 2.8) +set(CMAKE_MACOSX_RPATH 1) + +file(GLOB_RECURSE sources *.cpp) + +set (CMAKE_CXX_FLAGS "-Wall -std=c++11 -fno-stack-protector ${CMAKE_CXX_FLAGS}") + +execute_process( + COMMAND + $ENV{LIBTWML_HOME}/src/ops/scripts/get_inc.sh + RESULT_VARIABLE + TF_RES + OUTPUT_VARIABLE + TF_INC) + +if (NOT (${TF_RES} EQUAL "0")) + message(${TF_RES}) + message(FATAL_ERROR "Failed to get include path for tensorflow") +endif() + +execute_process( + COMMAND + $ENV{LIBTWML_HOME}/src/ops/scripts/get_lib.sh + RESULT_VARIABLE + TF_RES + OUTPUT_VARIABLE + TF_LIB) + +if (NOT (${TF_RES} EQUAL "0")) + message(${TF_RES}) + message(FATAL_ERROR "Failed to get lib path for tensorflow") +endif() + +find_path( + TWML_INC + NAMES "twml.h" + PATHS $ENV{LIBTWML_HOME}/include) + +add_library(twml_tf MODULE ${sources}) + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{LIBTWML_HOME}/cmake") + +if (UNIX) + if (APPLE) + set (CMAKE_CXX_FLAGS "-undefined dynamic_lookup -stdlib=libc++ ${CMAKE_CXX_FLAGS}") + # -Wl,-all_load ensures symbols not used by twml_tf are also included. + # -Wl,-noall_load limits the scope of the previous flag. + set (LINK_ALL_OPTION "-Wl,-all_load") + set (NO_LINK_ALL_OPTION "-Wl,-noall_load") + set(TF_FRAMEWORK_LIB ${TF_LIB}/libtensorflow_framework.1.dylib) + else() + # -Wl,--whole-archive ensures symbols not used by twml_tf are also included. + # -Wl,--no-whole-archive limits the scope of the previous flag. + set (LINK_ALL_OPTION "-Wl,--whole-archive") + set (NO_LINK_ALL_OPTION "-Wl,--no-whole-archive") + set(TF_FRAMEWORK_LIB ${TF_LIB}/libtensorflow_framework.so.1) + endif() +endif() + + +target_include_directories( + twml_tf + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${TWML_INC} + # TF_INC needs to be the last to avoid some weird white-spacing issues with generated Makefile. + ${TF_INC} # Needed because of some header files auto-generated during build time. 
+ ${TF_INC}/external/nsync/public/ + ) + +target_link_libraries(twml_tf + PUBLIC + # Since we are using twml_tf as the "one" dynamic library, + # we want it to have the C function symbols needed for other functions as well. + ${LINK_ALL_OPTION} twml ${NO_LINK_ALL_OPTION} + ${TF_FRAMEWORK_LIB} + ) diff --git a/twml/libtwml/src/ops/add1.cpp b/twml/libtwml/src/ops/add1.cpp new file mode 100644 index 000000000..66281841a --- /dev/null +++ b/twml/libtwml/src/ops/add1.cpp @@ -0,0 +1,92 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; + +REGISTER_OP("Add1") +.Attr("T: {float, double, int32}") +.Input("input1: T") +.Output("output: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + + +template +class Add1 : public OpKernel { + public: + explicit Add1(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Create an output tensor + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + // Add 1 to input and assign to output + const int N = input.size(); + for (int i = 0; i < N; i++) { + output_flat(i) = input(i) + 1; + } + } +}; + + +REGISTER_OP("Add1Grad") +.Attr("T: {float, double, int32}") +.Input("grad_output: T") +.Output("grad_input: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +template +class Add1Grad : public OpKernel { + public: + explicit Add1Grad(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& grad_output_tensor = context->input(0); + auto grad_output = grad_output_tensor.flat(); + + // Create an grad_input tensor + Tensor* grad_input_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, grad_output_tensor.shape(), + &grad_input_tensor)); + + auto grad_input_flat = grad_input_tensor->flat(); + + // Copy from grad_output to grad_input + const int N = grad_output.size(); + for (int i = 0; i < N; i++) { + grad_input_flat(i) = grad_output(i); + } + } +}; + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("Add1") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + Add1); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("Add1Grad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + Add1Grad); \ + +REGISTER(float); +REGISTER(double); +REGISTER(int32); diff --git a/twml/libtwml/src/ops/batch_prediction_request.cpp b/twml/libtwml/src/ops/batch_prediction_request.cpp new file mode 100644 index 000000000..a83c3ebcf --- /dev/null +++ b/twml/libtwml/src/ops/batch_prediction_request.cpp @@ -0,0 +1,183 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" +#include "resource_utils.h" + +REGISTER_OP("DecodeAndHashBatchPredictionRequest") +.Input("input_bytes: uint8") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("decode_mode: int = 0") +.Output("hashed_data_record_handle: resource") 
+.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that decodes batch prediction request and creates a handle to the batch of hashed data records. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + decode_mode: integer, indicates which decoding method to use. Let a sparse continuous + have a feature_name and a dict of {name: value}. 0 indicates feature_ids are computed + as hash(name). 1 indicates feature_ids are computed as hash(feature_name, name) + shared_name: name used by the resource handle inside the resource manager. + container: name used by the container of the resources. + +shared_name and container are required when inheriting from ResourceOpKernel. + +Input + input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest. + +Outputs + hashed_data_record_handle: A resource handle to the HashedDataRecordResource containing batch of HashedDataRecords. +)doc"); + +class DecodeAndHashBatchPredictionRequest : public OpKernel { + public: + explicit DecodeAndHashBatchPredictionRequest(OpKernelConstruction* context) + : OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + OP_REQUIRES_OK(context, context->GetAttr("decode_mode", &m_decode_mode)); + + OP_REQUIRES(context, keep_features.size() == keep_codes.size(), + errors::InvalidArgument("keep keys and values must have same size.")); + +#ifdef USE_DENSE_HASH + m_keep_map.set_empty_key(0); +#endif // USE_DENSE_HASH + + for (uint64_t i = 0; i < keep_features.size(); i++) { + m_keep_map[keep_features[i]] = keep_codes[i]; + } + } + + private: + twml::Map m_keep_map; + int64 m_decode_mode; + + void Compute(OpKernelContext* context) override { + try { + HashedDataRecordResource *resource = nullptr; + OP_REQUIRES_OK(context, makeResourceHandle(context, 0, &resource)); + + // Store the input bytes in the resource so it isnt freed before the resource. + // This is necessary because we are not copying the contents for tensors. + resource->input = context->input(0); + const uint8_t *input_bytes = resource->input.flat().data(); + twml::HashedDataRecordReader reader; + twml::HashedBatchPredictionRequest bpr; + reader.setKeepMap(&m_keep_map); + reader.setBuffer(input_bytes); + reader.setDecodeMode(m_decode_mode); + bpr.decode(reader); + + resource->common = std::move(bpr.common()); + resource->records = std::move(bpr.requests()); + + // Each datarecord has a copy of common features. 
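// (Illustrative aside, not in the original source: with, say,
// common_size == 100 bytes and num_records == 8, the shared common
// features alone contribute 100 * 8 = 800 bytes to total_size before
// the per-record sizes are accumulated below.)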
+ // Initialize total_size by common_size * num_records + int64 common_size = static_cast(resource->common.totalSize()); + int64 num_records = static_cast(resource->records.size()); + int64 total_size = common_size * num_records; + for (const auto &record : resource->records) { + total_size += static_cast(record.totalSize()); + } + + resource->total_size = total_size; + resource->num_labels = 0; + resource->num_weights = 0; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("DecodeAndHashBatchPredictionRequest").Device(DEVICE_CPU), + DecodeAndHashBatchPredictionRequest); + +REGISTER_OP("DecodeBatchPredictionRequest") +.Input("input_bytes: uint8") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Output("data_record_handle: resource") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that decodes batch prediction request and creates a handle to the batch of data records. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + shared_name: name used by the resource handle inside the resource manager. + container: name used by the container of the resources. + +shared_name and container are required when inheriting from ResourceOpKernel. + +Input + input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest. + +Outputs + data_record_handle: A resource handle to the DataRecordResource containing batch of DataRecords. +)doc"); + +class DecodeBatchPredictionRequest : public OpKernel { + public: + explicit DecodeBatchPredictionRequest(OpKernelConstruction* context) + : OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + + OP_REQUIRES(context, keep_features.size() == keep_codes.size(), + errors::InvalidArgument("keep keys and values must have same size.")); + +#ifdef USE_DENSE_HASH + m_keep_map.set_empty_key(0); +#endif // USE_DENSE_HASH + + for (uint64_t i = 0; i < keep_features.size(); i++) { + m_keep_map[keep_features[i]] = keep_codes[i]; + } + } + + private: + twml::Map m_keep_map; + + void Compute(OpKernelContext* context) override { + try { + DataRecordResource *resource = nullptr; + OP_REQUIRES_OK(context, makeResourceHandle(context, 0, &resource)); + + // Store the input bytes in the resource so it isnt freed before the resource. + // This is necessary because we are not copying the contents for tensors. 
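+      // Assigning the input Tensor into the resource shares (and thereby retains) the
+      // underlying buffer, so the bytes that the reader points into stay alive for the
+      // lifetime of the resource.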
+ resource->input = context->input(0); + const uint8_t *input_bytes = resource->input.flat().data(); + twml::DataRecordReader reader; + twml::BatchPredictionRequest bpr; + reader.setKeepMap(&m_keep_map); + reader.setBuffer(input_bytes); + bpr.decode(reader); + + resource->common = std::move(bpr.common()); + resource->records = std::move(bpr.requests()); + + resource->num_weights = 0; + resource->num_labels = 0; + resource->keep_map = &m_keep_map; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("DecodeBatchPredictionRequest").Device(DEVICE_CPU), + DecodeBatchPredictionRequest); diff --git a/twml/libtwml/src/ops/batch_prediction_request_v2.cpp b/twml/libtwml/src/ops/batch_prediction_request_v2.cpp new file mode 100644 index 000000000..3e89c9a0a --- /dev/null +++ b/twml/libtwml/src/ops/batch_prediction_request_v2.cpp @@ -0,0 +1,224 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include +#include "tensorflow_utils.h" +#include "resource_utils.h" + +#include + +template +class DecodeBatchPredictionRequestKernel : public OpKernel { + public: + explicit DecodeBatchPredictionRequestKernel(OpKernelConstruction* context) + : OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + + std::vector label_features; + std::vector weight_features; + + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + + OP_REQUIRES_OK(context, context->GetAttr("label_features", &label_features)); + OP_REQUIRES_OK(context, context->GetAttr("weight_features", &weight_features)); + OP_REQUIRES_OK(context, context->GetAttr("decode_mode", &m_decode_mode)); + + OP_REQUIRES(context, keep_features.size() == keep_codes.size(), + errors::InvalidArgument("keep keys and values must have same size.")); + +#ifdef USE_DENSE_HASH + m_keep_map.set_empty_key(0); + m_labels_map.set_empty_key(0); + m_weights_map.set_empty_key(0); +#endif // USE_DENSE_HASH + + for (uint64_t i = 0; i < keep_features.size(); i++) { + m_keep_map[keep_features[i]] = keep_codes[i]; + } + + for (uint64_t i = 0; i < label_features.size(); i++) { + m_labels_map[label_features[i]] = i; + } + + for (uint64_t i = 0; i < weight_features.size(); i++) { + m_weights_map[weight_features[i]] = i; + } + } + + protected: + twml::Map m_keep_map; + twml::Map m_labels_map; + twml::Map m_weights_map; + int64 m_decode_mode; + + template + void Decode(OpKernelContext* context, ResourceType *resource) { + resource->input = context->input(0); + const uint8_t *input_bytes = getInputBytes(resource->input, 0); + int num_labels = static_cast(m_labels_map.size()); + int num_weights = static_cast(m_weights_map.size()); + + typename RecordType::Reader reader; + twml::GenericBatchPredictionRequest bpr(num_labels, num_weights); + + reader.setKeepMap(&m_keep_map); + reader.setLabelsMap(&m_labels_map); + reader.setBuffer(input_bytes); + reader.setDecodeMode(m_decode_mode); + // Do not set weight map if it is empty. This will take a faster path. 
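+    // With no weights map registered, the reader is expected to skip per-feature
+    // weight lookups entirely; that is the faster path mentioned above.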
+ if (num_weights != 0) { + reader.setWeightsMap(&m_weights_map); + } + bpr.decode(reader); + + resource->common = std::move(bpr.common()); + resource->records = std::move(bpr.requests()); + + resource->num_labels = num_labels; + resource->num_weights = num_weights; + } +}; + + +REGISTER_OP("DecodeAndHashBatchPredictionRequestV2") +.Attr("InputType: {uint8, string}") +.Input("input_bytes: InputType") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("label_features: list(int)") +.Attr("weight_features: list(int) = []") +.Attr("decode_mode: int = 0") +.Output("hashed_data_record_handle: resource") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that decodes a list/batch of data records and creates a handle to the batch of hashed data records. + +Compared to DecodeAndHashBatchPredictionRequest, DecodeAndHashBatchPredictionRequestV2 is used for training instead +of serving. Thus label_features and weight_features[optional] must be passed, and labels and weights are extracted in +the output. +DecodeAndHashBatchPredictionRequestV2 controls what DataRecords we want to process together in a batch in training. +For instance, we can put all instances for a query in the same batch when training a ranking model. +Notice that this OP was added separately to make sure we would not break the API for DecodeAndHashBatchPredictionRequest. +It requires some discussions if we merge the two ops into a single .cpp file in a future API revision. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + label_features: list of feature ids representing the labels. + weight_features: list of feature ids representing the weights. Defaults to empty list. + decode_mode: integer, indicates which decoding method to use. Let a sparse continuous + have a feature_name and a dict of {name: value}. 0 indicates feature_ids are computed + as hash(name). 1 indicates feature_ids are computed as hash(feature_name, name) + +Input + input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest. + +Outputs + hashed_data_record_handle: A resource handle to the HashedDataRecordResource containing batch of HashedDataRecords. +)doc"); + +template +class DecodeAndHashBatchPredictionRequestV2 : + public DecodeBatchPredictionRequestKernel { + +public: + DecodeAndHashBatchPredictionRequestV2(OpKernelConstruction *context) + : DecodeBatchPredictionRequestKernel(context) { + } + + private: + void Compute(OpKernelContext* context) override { + try { + HashedDataRecordResource *resource = nullptr; + OP_REQUIRES_OK( + context, + makeResourceHandle(context, 0, &resource)); + + this->Decode(context, resource); + + // Each datarecord has a copy of common features. 
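+      // Same size accounting as DecodeAndHashBatchPredictionRequest above: the common
+      // features are charged once per record.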
+ // Initialize total_size by common_size * num_records + int64 common_size = static_cast(resource->common.totalSize()); + int64 num_records = static_cast(resource->records.size()); + int64 total_size = common_size * num_records; + for (const auto &record : resource->records) { + total_size += static_cast(record.totalSize()); + } + + resource->total_size = total_size; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("DecodeBatchPredictionRequestV2") +.Attr("InputType: {uint8, string}") +.Input("input_bytes: InputType") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("label_features: list(int)") +.Attr("weight_features: list(int) = []") +.Attr("decode_mode: int = 0") +.Output("data_record_handle: resource") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that decodes batch prediction request and creates a handle to the batch of data records. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + shared_name: name used by the resource handle inside the resource manager. + label_features: list of feature ids representing the labels. + weight_features: list of feature ids representing the weights. Defaults to empty list. + decode_mode: reserved, do not use. + +Input + input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest. + +Outputs + data_record_handle: A resource handle to the DataRecordResource containing batch of DataRecords. +)doc"); + + +template +class DecodeBatchPredictionRequestV2 : + public DecodeBatchPredictionRequestKernel { +public: + DecodeBatchPredictionRequestV2(OpKernelConstruction *context) + : DecodeBatchPredictionRequestKernel(context) { + } + +private: + void Compute(OpKernelContext* context) override { + try { + DataRecordResource *resource = nullptr; + OP_REQUIRES_OK( + context, + makeResourceHandle(context, 0, &resource)); + this->Decode(context, resource); + resource->keep_map = &(this->m_keep_map); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +#define REGISTER_DECODE_OPS(InputType) \ + REGISTER_KERNEL_BUILDER( \ + Name("DecodeAndHashBatchPredictionRequestV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("InputType"), \ + DecodeAndHashBatchPredictionRequestV2); \ + REGISTER_KERNEL_BUILDER( \ + Name("DecodeBatchPredictionRequestV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("InputType"), \ + DecodeBatchPredictionRequestV2); \ + +REGISTER_DECODE_OPS(uint8) +REGISTER_DECODE_OPS(string) diff --git a/twml/libtwml/src/ops/batch_prediction_response_writer.cpp b/twml/libtwml/src/ops/batch_prediction_response_writer.cpp new file mode 100644 index 000000000..4876dd48a --- /dev/null +++ b/twml/libtwml/src/ops/batch_prediction_response_writer.cpp @@ -0,0 +1,82 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + +REGISTER_OP("BatchPredictionResponseWriter") +.Attr("T: {float, double}") +.Input("keys: int64") +.Input("values: T") +.Output("result: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP that packages keys and values into a BatchPredictionResponse. + +values: input feature value. 
(float/double) +keys: feature ids from the original BatchPredictionRequest. (int64) + +Outputs + bytes: output BatchPredictionRequest serialized using Thrift into a uint8 tensor. +)doc"); + +template +class BatchPredictionResponseWriter : public OpKernel { + public: + explicit BatchPredictionResponseWriter(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& keys = context->input(0); + const Tensor& values = context->input(1); + + try { + // Ensure the inner dimension matches. + if (values.dim_size(values.dims() - 1) != keys.dim_size(keys.dims() - 1)) { + throw std::runtime_error("The sizes of keys and values need to match"); + } + + // set inputs as twml::Tensor + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + const twml::Tensor in_values_ = TFTensor_to_twml_tensor(values); + // no tensors in this op + const twml::Tensor dummy_dense_keys_; + const std::vector dummy_dense_values_; + + // call constructor BatchPredictionResponse + twml::BatchPredictionResponse tempResult( + in_keys_, in_values_, dummy_dense_keys_, dummy_dense_values_); + + // determine the length of the result + int len = tempResult.encodedSize(); + TensorShape result_shape = {1, len}; + + // Create an output tensor, the size is determined by the content of input. + Tensor* result = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, result_shape, + &result)); + twml::Tensor out_result = TFTensor_to_twml_tensor(*result); + + // Call writer of BatchPredictionResponse + tempResult.write(out_result); + } catch(const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchPredictionResponseWriter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + BatchPredictionResponseWriter); \ + +REGISTER(float); +REGISTER(double); diff --git a/twml/libtwml/src/ops/batch_prediction_tensor_response_writer.cpp b/twml/libtwml/src/ops/batch_prediction_tensor_response_writer.cpp new file mode 100644 index 000000000..b98d23206 --- /dev/null +++ b/twml/libtwml/src/ops/batch_prediction_tensor_response_writer.cpp @@ -0,0 +1,81 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + +REGISTER_OP("BatchPredictionTensorResponseWriter") +.Attr("T: list({string, int32, int64, float, double})") +.Input("keys: int64") +.Input("values: T") +.Output("result: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP that packages keys and dense tensors into a BatchPredictionResponse. + +values: list of tensors +keys: feature ids from the original BatchPredictionRequest. (int64) + +Outputs + bytes: output BatchPredictionRequest serialized using Thrift into a uint8 tensor. 
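+
+Note: the response is written as a [1, N] uint8 tensor, where N is the Thrift-encoded
+size of the BatchPredictionResponse.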
+)doc"); + +class BatchPredictionTensorResponseWriter : public OpKernel { + public: + explicit BatchPredictionTensorResponseWriter(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& keys = context->input(0); + + try { + // set keys as twml::Tensor + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + + // check sizes + uint64_t num_keys = in_keys_.getNumElements(); + uint64_t num_values = context->num_inputs() - 1; + + OP_REQUIRES(context, num_values % num_keys == 0, + errors::InvalidArgument("Number of dense tensors not multiple of dense keys")); + + // set dense tensor values + std::vector in_values_; + for (int i = 1; i < context->num_inputs(); i++) { + in_values_.push_back(TFTensor_to_twml_raw_tensor(context->input(i))); + } + + // no continuous predictions in this op, only tensors + const twml::Tensor dummy_cont_keys_; + const twml::Tensor dummy_cont_values_; + + // call constructor BatchPredictionResponse + twml::BatchPredictionResponse tempResult( + dummy_cont_keys_, dummy_cont_values_, in_keys_, in_values_); + + // determine the length of the result + int len = tempResult.encodedSize(); + TensorShape result_shape = {1, len}; + + // Create an output tensor, the size is determined by the content of input. + Tensor* result = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, result_shape, + &result)); + twml::Tensor out_result = TFTensor_to_twml_tensor(*result); + + // Call writer of BatchPredictionResponse + tempResult.write(out_result); + } catch(const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("BatchPredictionTensorResponseWriter").Device(DEVICE_CPU), + BatchPredictionTensorResponseWriter); diff --git a/twml/libtwml/src/ops/binary_sparse_dense_matmul.cpp b/twml/libtwml/src/ops/binary_sparse_dense_matmul.cpp new file mode 100644 index 000000000..0a7f02af3 --- /dev/null +++ b/twml/libtwml/src/ops/binary_sparse_dense_matmul.cpp @@ -0,0 +1,330 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TWML modified to optimize binary features: +// - Sparse tensor values are assumed to be binary, so only add operation is done +// rather than mul-add; +// - In house version of vectorization is used instead of Eigen; +// - Enable sharding and multithreading. + +#define EIGEN_USE_THREADS + +#include "binary_sparse_dense_matmul.h" +#include "binary_sparse_dense_matmul_impl.h" + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +namespace shape_inference { +// TODO: The `a_value` is supposed to be all ones. 
+// Users should not call this op directly but to use it from `sparse_op` python library. +// To make it consistent with original op, the signature remains the same currently, +// we will think a better way to contrain correct use of this op. +// CX-18174 +REGISTER_OP("BinarySparseTensorDenseMatMul") + .Input("a_indices: Tindices") + .Input("a_values: T") + .Input("a_shape: int64") + .Input("b: T") + .Output("product: T") + .Attr("T: type") + .Attr("Tindices: {int32,int64} = DT_INT64") + .Attr("adjoint_a: bool = false") + .Attr("adjoint_b: bool = false") + .SetShapeFn([](InferenceContext* c) { + DimensionHandle unused_dim; + ShapeHandle unused; + ShapeHandle b; + ShapeHandle a_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape)); + TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b)); + + bool adjoint_a; + bool adjoint_b; + TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a)); + TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b)); + + DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1); + DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0); + DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1); + DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0); + TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim)); + c->set_output(0, c->Matrix(output_left, output_right)); + return Status::OK(); + }); +} // namespace shape_inference + + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class BinarySparseTensorDenseMatMulOp : public OpKernel { + public: + explicit BinarySparseTensorDenseMatMulOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* a_indices; + const Tensor* a_values; + const Tensor* a_shape; + const Tensor* b; + OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices)); + OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values)); + OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape)); + OP_REQUIRES_OK(ctx, ctx->input("b", &b)); + + // Check that the dimensions of the two matrices are valid. + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b->shape()), + errors::InvalidArgument("Tensor 'b' is not a matrix")); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()), + errors::InvalidArgument("Tensor 'a_shape' is not a vector")); + + OP_REQUIRES( + ctx, a_shape->NumElements() == 2, + errors::InvalidArgument("Tensor 'a_shape' must have 2 elements")); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values->shape()), + errors::InvalidArgument("Tensor 'a_values' is not a vector")); + + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()), + errors::InvalidArgument("Tensor 'a_indices' is not a matrix")); + + const int64 nnz = a_indices->shape().dim_size(0); + OP_REQUIRES(ctx, nnz == a_values->NumElements(), + errors::InvalidArgument("Number of rows of a_indices does not " + "match number of entries in a_values")); + + OP_REQUIRES( + ctx, a_indices->shape().dim_size(1) == a_shape->NumElements(), + errors::InvalidArgument("Number of columns of a_indices does not match " + "number of entries in a_shape")); + + auto a_shape_t = a_shape->vec(); + const int64 outer_left = (adjoint_a_) ? 
a_shape_t(1) : a_shape_t(0); + const int64 outer_right = + (adjoint_b_) ? b->shape().dim_size(0) : b->shape().dim_size(1); + const int64 inner_left = (adjoint_a_) ? a_shape_t(0) : a_shape_t(1); + const int64 inner_right = + (adjoint_b_) ? b->shape().dim_size(1) : b->shape().dim_size(0); + + OP_REQUIRES( + ctx, inner_right == inner_left, + errors::InvalidArgument( + "Cannot multiply A and B because inner dimension does not match: ", + inner_left, " vs. ", inner_right, + ". Did you forget a transpose? " + "Dimensions of A: [", + a_shape_t(0), ", ", a_shape_t(1), + "). Dimensions of B: ", b->shape().DebugString())); + + TensorShape out_shape({outer_left, outer_right}); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + + if (out->NumElements() == 0) { + // If a has shape [0, x] or b has shape [x, 0], the output shape + // is a 0-element matrix, so there is nothing to do. + return; + } + + if (a_values->NumElements() == 0 || b->NumElements() == 0) { + // If a has shape [x, 0] and b has shape [0, y], the + // output shape is [x, y] where x and y are non-zero, so we fill + // the output with zeros. + out->flat().device(ctx->eigen_device()) = + out->flat().constant(T(0)); + return; + } + +#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ + if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ + Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ + Device, T, Tindices, ADJ_A, \ + ADJ_B>::Compute(ctx, a_indices, a_values, a_shape, b, out); \ + OP_REQUIRES_OK(ctx, functor_status); \ + } + + MAYBE_ADJOINT(false, false); + MAYBE_ADJOINT(false, true); + MAYBE_ADJOINT(true, false); + MAYBE_ADJOINT(true, true); + +#undef MAYBE_ADJOINT + } + + private: + bool adjoint_a_; + bool adjoint_b_; +}; + +#define REGISTER_CPU(TypeT, TypeIndex) \ + REGISTER_KERNEL_BUILDER( \ + Name("BinarySparseTensorDenseMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices") \ + .HostMemory("a_shape"), \ + BinarySparseTensorDenseMatMulOp); + +#define REGISTER_KERNELS_CPU(T) \ + REGISTER_CPU(T, int64); \ + REGISTER_CPU(T, int32) + +REGISTER_KERNELS_CPU(float); +REGISTER_KERNELS_CPU(double); +REGISTER_KERNELS_CPU(int32); +REGISTER_KERNELS_CPU(complex64); +REGISTER_KERNELS_CPU(complex128); + +namespace functor { + +namespace { +Status KOutOfBoundsError(int64 k, std::size_t i, int rhs_index_a, + std::size_t lhs_right) { + return errors::InvalidArgument("k (", k, ") from index[", i, ",", rhs_index_a, + "] out of bounds (>=", lhs_right, ")"); +} + +Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a, + int64 out_dim0) { + return errors::InvalidArgument("m (", m, ") from index[", i, ",", lhs_index_a, + "] out of bounds (>=", out_dim0, ")"); +} + +} // namespace + + +// The general functor just borrows the code from tf except that add is computed +// instead of mul-add. +template +struct SparseTensorDenseMatMulFunctor { + // Vectorize certain operations above this size. 
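+  // (The cutoff below is a heuristic: for narrow right-hand sides, the plain scalar
+  // loop avoids the per-nonzero overhead of Eigen chip expressions.)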
+ static const std::size_t kNumVectorize = 32; + + static Status Compute(OpKernelContext* ctx, + const Tensor *a_indices, + const Tensor *a_values, + const Tensor *a_shape, + const Tensor *b, + Tensor *out) { + return EigenCompute(ctx->eigen_device(), out->matrix(), + a_indices->matrix(), a_values->vec(), + b->matrix()); + } + + static Status EigenCompute(const CPUDevice& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, + typename TTypes::ConstMatrix b) { + const std::size_t nnz = a_values.size(); + const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1)); + const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0)); + const int lhs_index_a = ADJ_A ? 1 : 0; + const int rhs_index_a = ADJ_A ? 0 : 1; + + out.setZero(); + + if (rhs_right < kNumVectorize) { + // Disable vectorization if the RHS of output is too small + auto maybe_adjoint_b = MaybeAdjoint(b); + + for (std::size_t i = 0; i < nnz; ++i) { + const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); + const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); + if (!FastBoundsCheck(k, lhs_right)) { + return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); + } + if (!FastBoundsCheck(m, out.dimension(0))) { + return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); + } + for (std::size_t n = 0; n < rhs_right; ++n) { + const T b_value = maybe_adjoint_b(k, n); + out(m, n) += b_value; + } + } + } else { + // Vectorization via Eigen. + const int b_chip_index = ADJ_B ? 1 : 0; + +#define LOOP_NNZ(b_passed) \ + for (std::size_t i = 0; i < nnz; ++i) { \ + const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \ + const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \ + if (!FastBoundsCheck(k, lhs_right)) { \ + return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \ + } \ + if (!FastBoundsCheck(m, out.dimension(0))) { \ + return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \ + } \ + out.template chip<0>(m) += b_passed.template chip(k); \ + } + + + if (ADJ_B) { + // Perform transpose and conjugation on B once, since we chip out B's + // columns in the nnz loop. + Eigen::array shuffle; // preserve dimension order + shuffle[0] = 1; shuffle[1] = 0; + Eigen::Tensor col_major_conj_b = + b.swap_layout().shuffle(shuffle).conjugate(); + LOOP_NNZ(col_major_conj_b); + } else { + LOOP_NNZ(b); + } +#undef LOOP_NNZ + } + return Status::OK(); + } +}; + + +// We have only specified and optimised the case with no matrix transpose, +// since it is the most typical usage in productions. 
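+//
+// In effect, because every a_value is 1, the specialization below reduces the matmul
+// to a segment sum over rows of B:
+//
+//   for each nonzero (m, k) of A:  out.row(m) += b.row(k)
+//
+// which is what ParallelLookupAndSegmentSum implements.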
+template +struct SparseTensorDenseMatMulFunctor { + static Status Compute(OpKernelContext* ctx, + const Tensor *a_indices, + const Tensor *a_values, + const Tensor *a_shape, + const Tensor *b, + Tensor *out) { + auto a_indices_ptr = a_indices->flat().data(); + auto b_ptr = b->flat().data(); + auto out_ptr = out->flat().data(); + const int64 nnz = a_indices->shape().dim_size(0); + const int64 outer_left = a_shape->vec()(0); + const int64 outer_right = b->shape().dim_size(1); + ParallelLookupAndSegmentSum(ctx, a_indices_ptr, b_ptr, nnz, + outer_left, outer_right, out_ptr); + return Status::OK(); + } +}; + +} // namespace functor + +} // namespace tensorflow diff --git a/twml/libtwml/src/ops/binary_sparse_dense_matmul.h b/twml/libtwml/src/ops/binary_sparse_dense_matmul.h new file mode 100644 index 000000000..92494af52 --- /dev/null +++ b/twml/libtwml/src/ops/binary_sparse_dense_matmul.h @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TWML modified to optimize binary features +#ifndef TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace functor { + +template +struct SparseTensorDenseMatMulFunctor { + static EIGEN_ALWAYS_INLINE Status Compute( + const Device& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b); +}; + +template +class MaybeAdjoint; + +template +class MaybeAdjoint { + public: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaybeAdjoint(MATRIX m) : m_(m) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename MATRIX::Scalar operator()( + const typename MATRIX::Index i, const typename MATRIX::Index j) const { + return m_(i, j); + } + + private: + const MATRIX m_; +}; + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) { + return v; +} + +template +class MaybeAdjoint { + public: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaybeAdjoint(MATRIX m) : m_(m) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename MATRIX::Scalar operator()( + const typename MATRIX::Index i, const typename MATRIX::Index j) const { + return Eigen::numext::conj(m_(j, i)); + } + + private: + const MATRIX m_; +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ diff --git a/twml/libtwml/src/ops/binary_sparse_dense_matmul_impl.h b/twml/libtwml/src/ops/binary_sparse_dense_matmul_impl.h new file mode 100644 index 000000000..db61647cb --- /dev/null +++ b/twml/libtwml/src/ops/binary_sparse_dense_matmul_impl.h @@ -0,0 +1,145 @@ +#ifndef 
TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_IMPL_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace functor { + +// `ConservativeShard` is adopted rather than `Shard` in tensorflow because the +// original `Shard` may generate number of shards more than the number of +// threads, which is not ideal for this case, as it may cause too much overhead. +static void ConservativeShard(int max_parallelism, thread::ThreadPool *workers, + int64 total, int64 cost_per_unit, + std::function work) { + if (total == 0) { + return; + } + max_parallelism = std::min(max_parallelism, workers->NumThreads()); + if (max_parallelism <= 1) { + // Just inline the whole work since we only have 1 thread (core). + work(0, total); + return; + } + cost_per_unit = std::max(1LL, cost_per_unit); + // We shard [0, total) into "num_shards" shards. + // 1 <= num_shards <= num worker threads + // + // If total * cost_per_unit is small, it is not worth shard too + // much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000 + // is 10us. + static const int64 kMinCostPerShard = 10000; + const int num_shards = + std::max(1, std::min(static_cast(max_parallelism), + total * cost_per_unit / kMinCostPerShard)); + + // Each shard contains up to "block_size" units. [0, total) is sharded + // into: + // [0, block_size), [block_size, 2*block_size), ... + // The 1st shard is done by the caller thread and the other shards + // are dispatched to the worker threads. The last shard may be smaller than + // block_size. + const int64 block_size = (total + num_shards - 1) / num_shards; + if (block_size >= total) { + work(0, total); + return; + } + const int num_shards_used = (total + block_size - 1) / block_size; + BlockingCounter counter(num_shards_used - 1); + for (int64 start = block_size; start < total; start += block_size) { + auto limit = std::min(start + block_size, total); + workers->Schedule([&work, &counter, start, limit]() { + work(start, limit); // Compute the shard. + counter.DecrementCount(); // The shard is done. + }); + } + + // Inline execute the 1st shard. + work(0, std::min(block_size, total)); + counter.Wait(); +} + +static inline void VectorSum(float *a, const float *b, int n) { + for (int i = 0; i < n; ++i) { + a[i] += b[i]; + } +} + +// This func is to vectorize the computation of segment sum. +template +static void LookupAndSegmentSum(const Tindices *a_indices, const float *b, + int nnz, int outer_right, float *output) { + for (std::size_t i = 0; i < nnz; ++i) { + const Tindices m = a_indices[i * 2]; + const Tindices k = a_indices[i * 2 + 1]; + auto output_row_m = output + m * outer_right; + auto b_row_k = b + k * outer_right; + VectorSum(output_row_m, b_row_k, outer_right); + } +} + +// This func enables sharding and multithreading, it comes with an overhead of +// duplicating output buffer to achieve lock free output. So there should not +// be too many threads. 
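+//
+// Buffer layout sketch (T = number of worker threads, out_size = outer_left * outer_right):
+//   thread 0      accumulates directly into `output`;
+//   thread t >= 1 accumulates into temp_out + (t - 1) * padded_out_size;
+// the (T - 1) private buffers are then reduced into `output` with VectorSum.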
+template +static void ParallelLookupAndSegmentSum(OpKernelContext *ctx, + const Tindices *a_indices, + const float *b, int nnz, int outer_left, + int outer_right, float *output) { + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + int out_size = outer_left * outer_right; + if (worker_threads.num_threads <= 1) { + memset(output, 0, out_size * sizeof(float)); + LookupAndSegmentSum(a_indices, b, + nnz, outer_right, + output); + return; + } + + // this is to make buffer align with kAllocatorAlignment + int padded_out_size = (out_size + (Allocator::kAllocatorAlignment - 1)) & + ~(Allocator::kAllocatorAlignment - 1); + std::size_t num_bytes = + (worker_threads.num_threads - 1) * padded_out_size * sizeof(float); + auto buffer = std::unique_ptr(reinterpret_cast( + port::AlignedMalloc(num_bytes, Allocator::kAllocatorAlignment))); + float *temp_out = buffer.get(); + + std::atomic thread_index(0); + + auto task = [&](int64 start, int64 limit) { + int local_thread_index = thread_index++; + float *buf_ptr = nullptr; + if (local_thread_index == 0) { + buf_ptr = output; + } else { + buf_ptr = temp_out + (local_thread_index - 1) * padded_out_size; + } + memset(buf_ptr, 0, out_size * sizeof(float)); + + LookupAndSegmentSum(a_indices + start * 2, b, + limit - start, outer_right, + buf_ptr); + }; + + int cost_per_unit = outer_right; + + // We don't use tensorflow shard func as tf may create more shards than + // number of threads. + ConservativeShard(worker_threads.num_threads, worker_threads.workers, nnz, + static_cast(cost_per_unit), task); + + for (int i = 1; i < thread_index; ++i) { + VectorSum(output, temp_out + (i - 1) * padded_out_size, out_size); + } +} + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BINARY_SPARSE_TENSOR_DENSE_MATMUL_IMPL_H_ \ No newline at end of file diff --git a/twml/libtwml/src/ops/block_format_dataset.cpp b/twml/libtwml/src/ops/block_format_dataset.cpp new file mode 100644 index 000000000..fdf4a9543 --- /dev/null +++ b/twml/libtwml/src/ops/block_format_dataset.cpp @@ -0,0 +1,243 @@ +#include "block_format_reader.h" + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +#if !defined(DISABLE_ZLIB) +#include "tensorflow/core/lib/io/zlib_inputstream.h" +#endif + +#include + +#include +#include +#include + +using namespace tensorflow; + + +inline std::string stripPath(std::string const &file_name) { + const auto pos = file_name.find_last_of("/"); + if (pos == std::string::npos) return file_name; + return file_name.substr(pos + 1); +} + +inline std::string getExtension(std::string const &file_name) { + const auto stripped_file_name = stripPath(file_name); + const auto pos = stripPath(stripped_file_name).find_last_of("."); + if (pos == std::string::npos) return ""; + return stripped_file_name.substr(pos + 1); +} + +REGISTER_OP("BlockFormatDatasetV2") +.Input("filenames: string") +.Input("compression_type: string") +.Input("buffer_size: int64") +.Output("handle: variant") +.SetIsStateful() +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( + +Creates a dataset for streaming BlockFormat data in compressed (e.g. gzip), uncompressed formats. +This op also has the ability stream a dataset containing files from multiple formats mentioned above. + +filenames: A scalar or vector containing the name(s) of the file(s) to be read. 
+compression_type: A scalar string denoting the compression type. Can be 'none', 'zlib', 'auto'. +buffer_size: A scalar denoting the buffer size to use during decompression. + +Outputs + handle: A handle to the dataset. This handle is later used to create an iterator to stream the data from the dataset. + +)doc"); + + +class BlockFormatDatasetV2 : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase **output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + const auto filenames_flat = filenames_tensor->flat(); + const int64 num_files = filenames_tensor->NumElements(); + std::vector filenames; + filenames.reserve(num_files); + std::copy(filenames_flat.data(), + filenames_flat.data() + num_files, + std::back_inserter(filenames)); + + string compression_type; + OP_REQUIRES_OK( + ctx, tensorflow::data::ParseScalarArgument( + ctx, "compression_type", &compression_type)); + + int64 buffer_size = -1; + OP_REQUIRES_OK( + ctx, tensorflow::data::ParseScalarArgument( + ctx, "buffer_size", &buffer_size)); + + OP_REQUIRES(ctx, buffer_size >= 0, + errors::InvalidArgument( + "`buffer_size` must be >= 0 (0 == no buffering)")); + + OP_REQUIRES(ctx, + compression_type == "auto" || + compression_type == "gz" || + compression_type == "", + errors::InvalidArgument("Unknown extension: ", compression_type)); + + *output = new Dataset(ctx, std::move(filenames), compression_type, buffer_size); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, + std::vector filenames, + std::string compression_type, + int64 buffer_size) + : DatasetBase(DatasetContext(ctx)), + compression_type_(compression_type), + buffer_size_(buffer_size), + filenames_(std::move(filenames)) + {} + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { return "BlockFormatDatasetV2::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + Node* compression_type = nullptr; + Node* buffer_size = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + TF_RETURN_IF_ERROR( + b->AddScalar(buffer_size_, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {filenames, compression_type, buffer_size}, output)); + return Status::OK(); + } + + private: + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BlockFormat")})); + } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params ¶ms) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a file, so try to read the next record. 
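+          // The loop below cycles through three states: drain the current reader,
+          // advance to the next file once it returns OutOfRange, and signal
+          // end_of_sequence when no files remain.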
+ if (reader_) { + Tensor result_tensor(cpu_allocator(), DT_STRING, {}); + Status s = reader_->ReadNext(&result_tensor.scalar()()); + if (s.ok()) { + out_tensors->emplace_back(std::move(result_tensor)); + *end_of_sequence = false; + return Status::OK(); + } else if (!errors::IsOutOfRange(s)) { + return s; + } + + // We have reached the end of the current file, so maybe + // move on to next file. + reader_.reset(); + ++current_file_index_; + } + + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + // Actually move on to next file. + const string& next_filename = + dataset()->filenames_[current_file_index_]; + + auto compression_type = dataset()->compression_type_; + int64 buffer_size = dataset()->buffer_size_; + + if (compression_type == "auto") { + compression_type = getExtension(next_filename); + } + + if (compression_type != "gz" && compression_type != "") { + return errors::InvalidArgument("Unknown extension: ", compression_type); + } + + tensorflow::Env* env = tensorflow::Env::Default(); + TF_CHECK_OK(env->NewRandomAccessFile(next_filename, &file_)); + + // RandomAccessInputstream defaults the second param to "false". + // The second parameter "false" is the key issue. + // "false" assumes the ownership of the file is elsewhere. + // But making that "true" causes segfaults down the line. + // So keep the ownership of "file_" in this class and clean up properly. + file_stream_.reset(new tensorflow::io::RandomAccessInputStream(file_.get(), false)); + + if (compression_type == "gz") { + // unpack_stream does not take ownership of file_stream_ +#if !defined(DISABLE_ZLIB) + unpack_stream_.reset(new tensorflow::io::ZlibInputStream( + file_stream_.get(), + buffer_size, + buffer_size, + tensorflow::io::ZlibCompressionOptions::GZIP())); + reader_.reset(new BlockFormatReader(unpack_stream_.get())); +#else + return errors::InvalidArgument("libtwml compiled without zlib support"); +#endif + } else { + unpack_stream_.reset(nullptr); + reader_.reset(new BlockFormatReader(file_stream_.get())); + } + } while (true); + } + + private: + mutex mu_; + uint64_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_; + std::unique_ptr file_stream_; + std::unique_ptr unpack_stream_; + std::unique_ptr reader_ GUARDED_BY(mu_); + }; + + const std::string compression_type_; + const int64 buffer_size_; + const std::vector filenames_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("BlockFormatDatasetV2") + .Device(DEVICE_CPU), + BlockFormatDatasetV2); diff --git a/twml/libtwml/src/ops/block_format_reader.h b/twml/libtwml/src/ops/block_format_reader.h new file mode 100644 index 000000000..29450cc03 --- /dev/null +++ b/twml/libtwml/src/ops/block_format_reader.h @@ -0,0 +1,50 @@ +#pragma once + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +#include + +#include + +using tensorflow::int64; +using tensorflow::Status; +using std::string; + +class BlockFormatReader : twml::BlockFormatReader { + public: + explicit BlockFormatReader(tensorflow::io::InputStreamInterface *stream) + : twml::BlockFormatReader() , stream_(stream) { + } + + // Read the next record. 
+ // Returns OK on success, + // Returns OUT_OF_RANGE for end of file, or something else for an error. + Status ReadNext(string* record) { + if (this->next()) { + return stream_->ReadNBytes(this->current_size(), record); + } + return tensorflow::errors::OutOfRange("eof"); + } + + uint64_t read_bytes(void *dest, int size, int count) { + uint64_t bytesToRead = size * count; + std::string current; + // TODO: Try to merge ReadNBytes and the memcpy below + // ReadNBytes performs a memory copy already. + Status status = stream_->ReadNBytes(bytesToRead, ¤t); + if (!status.ok()) { + return 0; + } + memcpy(dest, current.c_str(), bytesToRead); + return count; + } + + private: + tensorflow::io::InputStreamInterface *stream_; + TF_DISALLOW_COPY_AND_ASSIGN(BlockFormatReader); +}; diff --git a/twml/libtwml/src/ops/compress_sample_ids.cpp b/twml/libtwml/src/ops/compress_sample_ids.cpp new file mode 100644 index 000000000..3053de471 --- /dev/null +++ b/twml/libtwml/src/ops/compress_sample_ids.cpp @@ -0,0 +1,138 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include // std::fill_n + +using namespace tensorflow; + +REGISTER_OP("CompressSampleIds") +.Attr("T: {int32}") +.Input("input: T") +.Output("output: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(c->kUnknownDim)); + return Status::OK(); + }); + + +template +class CompressSampleIds : public OpKernel { + public: + explicit CompressSampleIds(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + const int N = input.size(); + + // Check for improper input + bool error = (N > 0 && input(0) < 0); + for (int i = 1; !error && i < N; i++) { + error = input(i - 1) > input(i); + } + + OP_REQUIRES( + context, !error, + errors::InvalidArgument( + "Error in CompressSampleIds. 
SampleIds must be non-negative and non-decreasing" + ) + ); + + // choose output size, either last input element + 1, or 0 + int output_size = 0; + if (N > 0) { + output_size = input(N - 1) + 1; + } + + // Create an output tensor + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, TensorShape({output_size}), &output_tensor) + ); + auto output_flat = output_tensor->flat(); + + // Zero-initialize output + for (int i = 0; i < output_size; i++) { + output_flat(i) = 0; + } + + // count how many of each input element + for (int i = 0; i < N; i++) { + output_flat(input(i)) ++; + } + } +}; + +REGISTER_OP("DecompressSampleIds") +.Attr("T: {int32}") +.Input("input: T") +.Output("output: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(c->kUnknownDim)); + return Status::OK(); + }); + + +template +class DecompressSampleIds : public OpKernel { + public: + explicit DecompressSampleIds(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + const int N = input.size(); + + // Check for improper input + bool error = false; + int output_size = 0; + for (int i = 0; !error && i < N; i++) { + error = input(i) < 0; + output_size += input(i); + } + + OP_REQUIRES( + context, !error, + errors::InvalidArgument( + "Error in DecompressSampleIds. Inputs must be non-negative." + ) + ); + + // Create an output tensor + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, TensorShape({output_size}),&output_tensor) + ); + auto output_flat = output_tensor->flat(); + + T *output_data = output_flat.data(); + for (int current_sample = 0; current_sample < N; current_sample++) { + std::fill_n(output_data, input(current_sample), current_sample); + output_data += input(current_sample); + } + } +}; + + + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("CompressSampleIds") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + CompressSampleIds); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("DecompressSampleIds") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DecompressSampleIds); \ + \ + +REGISTER(int32); diff --git a/twml/libtwml/src/ops/contrib/get_substrings.cpp b/twml/libtwml/src/ops/contrib/get_substrings.cpp new file mode 100644 index 000000000..8cd167e65 --- /dev/null +++ b/twml/libtwml/src/ops/contrib/get_substrings.cpp @@ -0,0 +1,116 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "../tensorflow_utils.h" +#include "../resource_utils.h" + +#include +#include + +using std::string; + +void join(const std::set& v, char c, string& s) { + s.clear(); + std::set::iterator it = v.begin(); + while (it != v.end()) { + s += *it; + it++; + if (it != v.end()) s+= c; + } +} + +// cpp function that computes substrings of a given word +std::string computeSubwords(std::string word, int32_t minn, int32_t maxn) { + std::string word2 = "<" + word + ">"; + std::set ngrams; + std::string s; + ngrams.insert(word); + ngrams.insert(word2); + for (size_t i = 0; i < word2.size(); i++) { + if ((word2[i] & 0xC0) == 0x80) continue; + for (size_t j = minn; i+j <= word2.size() && j <= maxn; j++) { + ngrams.insert(word2.substr(i, j)); + } + } + join(ngrams, ';', s); + ngrams.clear(); + return s; +} + +// tf-op 
function that computes substrings for a given tensor of words +template< typename ValueType> + +void ComputeSubStringsTensor(OpKernelContext *context, int32 min_n, int32 max_n) { + try { + const Tensor& values = context->input(0); + + auto values_flat = values.flat(); + + // batch_size from input_size : + const int batch_size = values_flat.size(); + + // define the output tensor + Tensor* substrings = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, values.shape(), &substrings)); + + auto substrings_flat = substrings->flat(); + // compute substrings for the given tensor values + for (int64 i = 0; i < batch_size; i++) { + substrings_flat(i) = computeSubwords(values_flat(i), min_n, max_n); + } + } + catch (const std::exception &err) { + context->CtxFailureWithWarning(errors::InvalidArgument(err.what())); + } +} + +REGISTER_OP("GetSubstrings") +.Attr("ValueType: {string}") +.Attr("min_n: int") +.Attr("max_n: int") +.Input("values: ValueType") +.Output("substrings: ValueType") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP to convert word to substrings of length between min_n and max_n. + +Attr + min_n,max_n: The size of the substrings. + +Input + values: 1D input tensor containing the values. + +Outputs + substrings: A string tensor where substrings are joined by ";". +)doc"); + +template +class GetSubstrings : public OpKernel { + public: + explicit GetSubstrings(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("min_n", &min_n)); + OP_REQUIRES_OK(context, context->GetAttr("max_n", &max_n)); + } + + private: + int32 min_n; + int32 max_n; + void Compute(OpKernelContext *context) override { + ComputeSubStringsTensor(context, min_n, max_n); + } +}; + + +#define REGISTER_SUBSTRINGS(ValueType) \ + REGISTER_KERNEL_BUILDER( \ + Name("GetSubstrings") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("ValueType"), \ + GetSubstrings); \ + +REGISTER_SUBSTRINGS(string) diff --git a/twml/libtwml/src/ops/data_record.cpp b/twml/libtwml/src/ops/data_record.cpp new file mode 100644 index 000000000..71ea72ac4 --- /dev/null +++ b/twml/libtwml/src/ops/data_record.cpp @@ -0,0 +1,1891 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include +#include +#include "tensorflow_utils.h" +#include "resource_utils.h" + +#include + +using std::string; + +REGISTER_OP("DecodeDataRecord") +.Attr("InputType: {uint8, string}") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("label_features: list(int)") +.Attr("weight_features: list(int) = []") +.Input("input_bytes: InputType") +.Output("data_record_handle: resource") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that creates a handle for the datarecord. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + label_features: list of feature ids representing the labels. + weight_features: list of feature ids representing the weights. Defaults to empty list. + shared_name: name used by the resource handle inside the resource manager. + container: name used by the container of the resources. + +shared_name and container are required when inheriting from ResourceOpKernel. + +Input + input_bytes: Input tensor containing the serialized batch of HashedDataRecords. 
+ +Outputs + data_record_handle: A resource handle to the DataRecord struct. +)doc"); + +template +class DecodeDataRecord : public OpKernel { + public: + explicit DecodeDataRecord(OpKernelConstruction* context) + : OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + + std::vector label_features; + std::vector weight_features; + + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + OP_REQUIRES_OK(context, context->GetAttr("label_features", &label_features)); + OP_REQUIRES_OK(context, context->GetAttr("weight_features", &weight_features)); + + OP_REQUIRES(context, keep_features.size() == keep_codes.size(), + errors::InvalidArgument("keep keys and values must have same size.")); + +#ifdef USE_DENSE_HASH + m_keep_map.set_empty_key(0); + m_labels_map.set_empty_key(0); + m_weights_map.set_empty_key(0); +#endif // USE_DENSE_HASH + + for (uint64_t i = 0; i < keep_features.size(); i++) { + m_keep_map[keep_features[i]] = keep_codes[i]; + } + + for (uint64_t i = 0; i < label_features.size(); i++) { + m_labels_map[label_features[i]] = i; + } + + for (uint64_t i = 0; i < weight_features.size(); i++) { + m_weights_map[weight_features[i]] = i; + } + } + + private: + twml::Map m_keep_map; + twml::Map m_labels_map; + twml::Map m_weights_map; + + void Compute(OpKernelContext* context) override { + try { + DataRecordResource *resource = nullptr; + OP_REQUIRES_OK(context, makeResourceHandle(context, 0, &resource)); + + // Store the input bytes in the resource so it isnt freed before the resource. + // This is necessary because we are not copying the contents for tensors. + resource->input = context->input(0); + int batch_size = getBatchSize(resource->input); + int num_labels = static_cast(m_labels_map.size()); + int num_weights = static_cast(m_weights_map.size()); + + twml::DataRecordReader reader; + reader.setKeepMap(&m_keep_map); + reader.setLabelsMap(&m_labels_map); + + // Do not set weight map if it is empty. This will take a faster path. + if (num_weights != 0) { + reader.setWeightsMap(&m_weights_map); + } + + resource->records.clear(); + resource->records.reserve(batch_size); + for (int i = 0; i < batch_size; i++) { + resource->records.emplace_back(num_labels, num_weights); + } + + for (int64 id = 0; id < batch_size; id++) { + const uint8_t *input_bytes = getInputBytes(resource->input, id); + reader.setBuffer(input_bytes); + // decode the reader + resource->records[id].decode(reader); + } + // This should be fine because m_keep_map should never go out of scope. 
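+      // (m_keep_map is a member of this kernel, and the kernel is expected to outlive
+      // any step that dereferences the resource, so storing a raw pointer is assumed
+      // to be safe here.)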
+ resource->keep_map = &m_keep_map; + resource->num_weights = num_weights; + resource->num_labels = num_labels; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +int64_t count_if_exists(const twml::DataRecord::BinaryFeatures &set, + const twml::Map *const keep_map) { + int64_t count = 0; + for (const auto &key : set) { + if (keep_map->find(key) == keep_map->end()) continue; + count++; + } + return count; +} + +// This works for continuous, discrete, and string features +template +int64_t count_if_exists(const twml::Map &map, + const twml::Map *const keep_map) { + int64_t count = 0; + for (const auto &elem : map) { + if (keep_map->find(elem.first) == keep_map->end()) continue; + count++; + } + return count; +} + +int64_t count_if_exists(const twml::DataRecord::SparseBinaryFeatures &map, + const twml::Map *const keep_map) { + int64_t count = 0; + for (const auto &elem : map) { + if (keep_map->find(elem.first) == keep_map->end()) continue; + count += elem.second.size(); + } + return count; +} + +int64_t count_if_exists(const twml::DataRecord::SparseContinuousFeatures &map, + const twml::Map *const keep_map) { + int64_t count = 0; + for (const auto &elem : map) { + if (keep_map->find(elem.first) == keep_map->end()) continue; + count += elem.second.size(); + } + return count; +} + +REGISTER_OP("GetBinaryFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads binary features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + values: always set to 1 (float) +)doc"); + +class GetBinaryFeatures : public OpKernel { + public: + explicit GetBinaryFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_binary_size = count_if_exists(common.getBinary(), handle->keep_map); + int64 total_binary_size = records.size() * common_binary_size; + for (int id = 0; id < records.size(); id++) { + total_binary_size += count_if_exists(handle->records[id].getBinary(), handle->keep_map); + } + const int total_size = static_cast(total_binary_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* ids = nullptr; + Tensor* values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &values)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto ids_flat = ids->flat(); + auto values_flat = values->flat(); + + for (int64 id = 0; id < records.size(); id++) { + for (const auto &it : common.getBinary()) { + if (handle->keep_map->find(it) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it; + offset++; + } + for (const auto &it : records[id].getBinary()) { + if (handle->keep_map->find(it) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it; + offset++; + } + } + // All the values for binary features are 1. 
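+      // Binary features are presence-only, so a constant 1.0 is emitted for every
+      // (id, key) pair; downstream code can then treat them uniformly with
+      // continuous features.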
+ std::fill(values_flat.data(), values_flat.data() + total_size, 1); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetContinuousFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads continuous features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: Datarecord keys (int64) + values: Datarecord values(float) +)doc"); + +class GetContinuousFeatures : public OpKernel { + public: + explicit GetContinuousFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_continuous_size = count_if_exists(common.getContinuous(), handle->keep_map); + int64 total_continuous_size = records.size() * common_continuous_size; + for (int id = 0; id < records.size(); id++) { + total_continuous_size += count_if_exists(handle->records[id].getContinuous(), + handle->keep_map); + } + const int total_size = static_cast(total_continuous_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* values = nullptr; + Tensor* ids = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &values)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto values_flat = values->flat(); + auto ids_flat = ids->flat(); + + for (int64 id = 0; id < records.size(); id++) { + for (const auto &it : common.getContinuous()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + values_flat(offset) = it.second; + offset++; + } + for (const auto &it : records[id].getContinuous()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + values_flat(offset) = it.second; + offset++; + } + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetDiscreteFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads discrete features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + values: DataRecord values(int64) +)doc"); + +class GetDiscreteFeatures : public OpKernel { + public: + explicit GetDiscreteFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_discrete_size = count_if_exists(common.getDiscrete(), handle->keep_map); + int64 total_discrete_size = records.size() * common_discrete_size; + for (int id 
= 0; id < records.size(); id++) { + total_discrete_size += count_if_exists(handle->records[id].getDiscrete(), + handle->keep_map); + } + const int total_size = static_cast(total_discrete_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* values = nullptr; + Tensor* ids = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &values)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto values_flat = values->flat(); + auto ids_flat = ids->flat(); + + for (int64 id = 0; id < records.size(); id++) { + for (const auto &it : common.getDiscrete()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + values_flat(offset) = it.second; + offset++; + } + for (const auto &it : records[id].getDiscrete()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + values_flat(offset) = it.second; + offset++; + } + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetStringFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("names: string") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads string features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + names: DataRecord values(string) + values: always set to 1 (float) +)doc"); + +class GetStringFeatures : public OpKernel { + public: + explicit GetStringFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_string_size = count_if_exists(common.getString(), handle->keep_map); + int64 total_string_size = records.size() * common_string_size; + for (int id = 0; id < records.size(); id++) { + total_string_size += count_if_exists(handle->records[id].getString(), + handle->keep_map); + } + const int total_size = static_cast(total_string_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* names = nullptr; + Tensor* ids = nullptr; + Tensor*values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &names)); + OP_REQUIRES_OK(context, context->allocate_output(3, shape, &values)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto names_flat = names->flat(); + auto ids_flat = ids->flat(); + auto values_flat = values->flat(); + + std::fill(values_flat.data(), values_flat.data() + total_size, 1); + for (int64 id = 0; id < records.size(); id++) { + for (const auto &it : common.getString()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it.second; + offset++; + } + for (const auto &it : records[id].getString()) { + if 
(handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it.second; + offset++; + } + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetSparseBinaryFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("names: string") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads sparse binary features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + names: DataRecord values(string) + values: always set to 1 (float) +)doc"); + +class GetSparseBinaryFeatures : public OpKernel { + public: + explicit GetSparseBinaryFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_sparse_binary_size = count_if_exists(common.getSparseBinary(), handle->keep_map); + int64 total_sparse_binary_size = records.size() * common_sparse_binary_size; + for (int id = 0; id < records.size(); id++) { + total_sparse_binary_size += count_if_exists(handle->records[id].getSparseBinary(), + handle->keep_map); + } + const int total_size = static_cast(total_sparse_binary_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* names = nullptr; + Tensor* ids = nullptr; + Tensor* values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &names)); + OP_REQUIRES_OK(context, context->allocate_output(3, shape, &values)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto names_flat = names->flat(); + auto ids_flat = ids->flat(); + auto values_flat = values->flat(); + + // All the values for sparse binary features are 1. 
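+      // The four parallel outputs form a COO-style encoding, e.g.
+      // (hypothetical) record 2 holding feature 55 -> {"red", "blue"} emits
+      // ids = [2, 2], keys = [55, 55], names = ["red", "blue"], values = [1, 1].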
+ std::fill(values_flat.data(), values_flat.data() + total_size, 1); + for (int64 id = 0; id < records.size(); id++) { + for (const auto &it : common.getSparseBinary()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + for (const auto &it_inner : it.second) { + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it_inner; + offset++; + } + } + for (const auto &it : records[id].getSparseBinary()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + for (const auto &it_inner : it.second) { + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it_inner; + offset++; + } + } + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetSparseContinuousFeatures") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: float") +.Output("names: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that reads sparse continuous features +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + values: DataRecord values(float) + names: DataRecord values(string) +)doc"); + +class GetSparseContinuousFeatures : public OpKernel { + public: + explicit GetSparseContinuousFeatures(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + + int64 common_sparse_continuous_size = count_if_exists(common.getSparseContinuous(), + handle->keep_map); + int64 total_sparse_continuous_size = records.size() * common_sparse_continuous_size; + for (int id = 0; id < records.size(); id++) { + total_sparse_continuous_size += count_if_exists(handle->records[id].getSparseContinuous(), + handle->keep_map); + } + const int total_size = static_cast(total_sparse_continuous_size); + + TensorShape shape = {total_size}; + Tensor* keys = nullptr; + Tensor* values = nullptr; + Tensor* names = nullptr; + Tensor* ids = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &values)); + OP_REQUIRES_OK(context, context->allocate_output(3, shape, &names)); + + uint64_t offset = 0; + auto keys_flat = keys->flat(); + auto values_flat = values->flat(); + auto names_flat = names->flat(); + auto ids_flat = ids->flat(); + + for (int64 id = 0; id < records.size(); id++) { + // copying the contents of the maps of maps + for (const auto &it : common.getSparseContinuous()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + // for each id; iterate through the number of maps corresponding to that id + for (const auto &it_inner : it.second) { + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it_inner.first; + values_flat(offset) = it_inner.second; + offset++; + } + } + // copying the contents of the maps of maps + for (const auto &it : records[id].getSparseContinuous()) { + if (handle->keep_map->find(it.first) == handle->keep_map->end()) continue; + // for each id; iterate through the number of maps corresponding to that id + for 
(const auto &it_inner : it.second) { + ids_flat(offset) = id; + keys_flat(offset) = it.first; + names_flat(offset) = it_inner.first; + values_flat(offset) = it_inner.second; + offset++; + } + } + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetBatchSizeFromDataRecord") +.Input("data_record_handle: resource") +.Output("batch_size: int64") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that returns batch size from the data record. +Input + data_record_handle: Resource handle to DataRecord + +Outputs + batch_size: Number of records held in the handle. +)doc"); + +class GetBatchSizeFromDataRecord : public OpKernel { + public: + explicit GetBatchSizeFromDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + Tensor *output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = handle->records.size(); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetLabelsFromDataRecord") +.Input("data_record_handle: resource") +.Output("labels: float") +.Attr("default_label: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns labels from the data record. + +Attr + default_label: The value used when a label is absent in a data record. + +Input + data_record_handle: Resource handle to DataRecord + +Outputs + labels: A 2D tensor of size [batch_size, num_labels] containing the label values. +)doc"); + +class GetLabelsFromDataRecord : public OpKernel { + private: + float default_label; + + public: + explicit GetLabelsFromDataRecord(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("default_label", &default_label)); + } + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const int num_labels = static_cast(handle->num_labels); + TensorShape shape = {static_cast(handle->records.size()), num_labels}; + + Tensor *labels; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &labels)); + + // The default value of label is not present in data record is std::nanf + // For continuous labels, change that to a default_label or label. + auto func = [this](float label) -> float { + return std::isnan(label) ? default_label : label; + }; + + auto labels_data = labels->flat().data(); + for (const auto &record : records) { + const auto& rec_labels = record.labels(); + labels_data = std::transform(rec_labels.begin(), rec_labels.end(), labels_data, func); + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetWeightsFromDataRecord") +.Input("data_record_handle: resource") +.Output("weights: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns weights from the data record. +Input + data_record_handle: Resource handle to DataRecord + +Outputs + weights: A 2D tensor of size [batch_size, num_weights] containing the weight values. 
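+
+For example (hypothetical values), with batch_size = 2 and num_weights = 3,
+the per-record weight vectors are stacked row-major:
+::
+  record 0 weights: [w00, w01, w02]
+  record 1 weights: [w10, w11, w12]
+  weights output  : [[w00, w01, w02],
+                     [w10, w11, w12]]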
+)doc"); + +class GetWeightsFromDataRecord : public OpKernel { + public: + explicit GetWeightsFromDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const int num_weights = static_cast(handle->num_weights); + TensorShape shape = {static_cast(handle->records.size()), num_weights}; + + Tensor *weights; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &weights)); + + auto weights_data = weights->flat().data(); + for (const auto &record : records) { + const auto& rec_weights = record.weights(); + weights_data = std::copy(rec_weights.begin(), rec_weights.end(), weights_data); + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +template +void SetValueGroup( +const FeatureType& type, +const int64& feature_id, +const int64& id, +const ValueType& default_value, +TensorType values_flat) { + auto it = type.find(feature_id); + values_flat(id) = (it == type.end()) ? default_value : it->second; +} + +template +// overloading for BinaryFeatures; as it needs to set a value of 1 +void SetValueGroup( +const twml::DataRecord::BinaryFeatures& type, +const int64& feature_id, +const int64& id, +const ValueType& default_value, +TensorType values_flat) { + auto it = type.find(feature_id); + values_flat(id) = (it == type.end()) ? default_value : 1; +} + +// Helper for Group Extraction of Dense Features +template +void ComputeHelperGroupFeaturesAsTensors( +OpKernelContext* context, +const std::vector& feature_ids, +ValueType& default_value, +std::function f) { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + // Output shape is 2D; where the first dimension corresponds to the batch_size + // and the second corresponds to the number of features passed to the TF Op. + const int batch_size = static_cast(handle->records.size()); + const int num_feature_ids = static_cast(feature_ids.size()); + TensorShape shape = {batch_size, num_feature_ids}; + + // Define the output + Tensor* values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &values)); + auto values_flat = values->flat(); + + for (int64 id = 0; id < records.size(); id++) { + const auto &type = f(records[id]); + const auto id_offset = id * feature_ids.size(); + for (int64 fid = 0; fid < feature_ids.size(); fid++) { + auto feature_id = feature_ids[fid]; + // The value is set to default if it does not exist in the current DataRecord + SetValueGroup(type, feature_id, id_offset + fid, default_value, values_flat); + } + } +} + +// Helper for Single Extraction of Dense Features +template +void ComputeHelperFeaturesAsTensors( +OpKernelContext* context, +ValueType& default_value, +int64 feature_id, +std::function f) { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + // Output shape is 2D; where the first dimension corresponds to the batch_size + // and the second corresponds to the number of features passed to the TF Op. 
+// Helper for Single Extraction of Dense Features
+template<typename FeatureType, typename ValueType>
+void ComputeHelperFeaturesAsTensors(
+    OpKernelContext* context,
+    ValueType& default_value,
+    int64 feature_id,
+    std::function<const FeatureType&(const twml::DataRecord&)> f) {
+  auto handle = getHandle<DataRecordResource>(context, 0);
+  const auto &records = handle->records;
+  // Output shape is 1D, with one value per DataRecord in the batch.
+  const int total_size = static_cast<int>(handle->records.size());
+  TensorShape shape = {total_size};
+
+  // Define the output
+  Tensor* values = nullptr;
+  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &values));
+  auto values_flat = values->flat<ValueType>();
+  for (int64 id = 0; id < records.size(); id++) {
+    const auto &type = f(records[id]);
+    SetValueGroup(type, feature_id, id, default_value, values_flat);
+  }
+}
+
+REGISTER_OP("GetBinaryAsTensor")
+.Input("data_record_handle: resource")
+.Attr("feature_id: int")
+.Attr("default_value: float")
+.Output("values: float")
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+    return Status::OK();
+  }).Doc(R"doc(
+A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id.
+Input
+  data_record_handle: Resource handle to DataRecord
+Attr
+  feature_id: Id representing the feature whose values will be extracted.
+  default_value: The default value used when the feature is missing from the current DataRecord.
+Outputs
+  values: A Tensor corresponding to the value of the feature_id across multiple DataRecords
+)doc");
+
+class GetBinaryAsTensor : public OpKernel {
+ private:
+  int64 feature_id;
+  float default_value;
+
+ public:
+  explicit GetBinaryAsTensor(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id));
+    OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      std::function<const twml::DataRecord::BinaryFeatures&(const twml::DataRecord&)> f =
+        [](const twml::DataRecord& record) -> const twml::DataRecord::BinaryFeatures& { return record.getBinary(); };
+      ComputeHelperFeaturesAsTensors(context, default_value, feature_id, f);
+    } catch (const std::exception &e) {
+      context->CtxFailureWithWarning(errors::InvalidArgument(e.what()));
+    }
+  }
+};
+
+REGISTER_OP("GetContinuousAsTensor")
+.Input("data_record_handle: resource")
+.Attr("feature_id: int")
+.Attr("default_value: float")
+.Output("values: float")
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+    return Status::OK();
+  }).Doc(R"doc(
+A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id.
+Input
+  data_record_handle: Resource handle to DataRecord
+Attr
+  feature_id: Id representing the feature whose values will be extracted.
+  default_value: The default value used when the feature is missing from the current DataRecord.
+Outputs + values: A Tensor corresponding to the value of the feature_id across multiple DataRecords +)doc"); + +class GetContinuousAsTensor : public OpKernel { + private: + int64 feature_id; + float default_value; + + public: + explicit GetContinuousAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::ContinuousFeatures& { return record.getContinuous(); }; + ComputeHelperFeaturesAsTensors(context, default_value, feature_id, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetDiscreteAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_id: int") +.Attr("default_value: int") +.Output("values: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_id: Id representing the feature whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. +Outputs + values: A Tensor corresponding to the value of the feature_id across multiple DataRecords +)doc"); + +class GetDiscreteAsTensor : public OpKernel { + private: + int64 feature_id; + int64 default_value; + + public: + explicit GetDiscreteAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::DiscreteFeatures& { return record.getDiscrete(); }; + ComputeHelperFeaturesAsTensors(context, default_value, feature_id, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetStringAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_id: int") +.Attr("default_value: string") +.Output("names: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_id: Id representing the feature whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. 
+Outputs + names: A Tensor corresponding to the value of the feature_id across multiple DataRecords +)doc"); + +class GetStringAsTensor : public OpKernel { + private: + int64 feature_id; + string default_value; + + public: + explicit GetStringAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::StringFeatures& { return record.getString(); }; + ComputeHelperFeaturesAsTensors(context, default_value, feature_id, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + + +REGISTER_OP("GetBinaryGroupAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_ids: list(int)") +.Attr("default_value: float") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_ids: List of ids representing the features whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. +Outputs + values: A Tensor corresponding to the values of the feature_ids across multiple DataRecords +)doc"); + + +class GetBinaryGroupAsTensor : public OpKernel { + private: + float default_value; + std::vector feature_ids; + + public: + explicit GetBinaryGroupAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_ids", &feature_ids)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::BinaryFeatures& { return record.getBinary(); }; + ComputeHelperGroupFeaturesAsTensors(context, feature_ids, default_value, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + + +REGISTER_OP("GetContinuousGroupAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_ids: list(int)") +.Attr("default_value: float") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_ids: List of ids representing the features whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. 
+Outputs + values: A Tensor corresponding to the values of the feature_ids across multiple DataRecords +)doc"); + +class GetContinuousGroupAsTensor : public OpKernel { + private: + float default_value; + std::vector feature_ids; + + public: + explicit GetContinuousGroupAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_ids", &feature_ids)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::ContinuousFeatures& { return record.getContinuous(); }; + ComputeHelperGroupFeaturesAsTensors(context, feature_ids, default_value, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetDiscreteGroupAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_ids: list(int)") +.Attr("default_value: int") +.Output("values: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_ids: List of ids representing the features whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. +Outputs + values: A Tensor corresponding to the values of the feature_ids across multiple DataRecords +)doc"); + +class GetDiscreteGroupAsTensor : public OpKernel { + private: + std::vector feature_ids; + int64 default_value; + + public: + explicit GetDiscreteGroupAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_ids", &feature_ids)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::DiscreteFeatures& { return record.getDiscrete(); }; + ComputeHelperGroupFeaturesAsTensors(context, feature_ids, default_value, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetStringGroupAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_ids: list(int)") +.Attr("default_value: string") +.Output("names: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns a Dense Tensor with the values of a particular feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_ids: List of ids representing the features whose values will be extracted. + default_value: default_value to be inputted if the values are missing from the current DataRecord. 
+Outputs + names: A Tensor corresponding to the values of the feature_ids across multiple DataRecords +)doc"); + +class GetStringGroupAsTensor : public OpKernel { + private: + std::vector feature_ids; + string default_value; + + public: + explicit GetStringGroupAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_ids", &feature_ids)); + OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value)); + } + + void Compute(OpKernelContext* context) override { + try { + std::function f = + [](const twml::DataRecord& record) ->const twml::DataRecord::StringFeatures& { return record.getString(); }; + ComputeHelperGroupFeaturesAsTensors(context, feature_ids, default_value, f); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetSparseBinaryAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_id: int") +.Output("ids: int64") +.Output("keys: int64") +.Output("names: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns tensors corresponding to the ids, keys and names of a particular +feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_id: Id representing the feature whose values will be extracted. +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + names: DataRecord values(string) +)doc"); +class GetSparseBinaryAsTensor : public OpKernel { + private: + int64 feature_id; + + public: + explicit GetSparseBinaryAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + } + + void Compute(OpKernelContext* context) override { + try { + // We need two passes to the data: + // 1 to compute the output size of the tensor + // 2 to copy the values to the tensor + auto handle = getHandle(context, 0); + const auto &records = handle->records; + + // Creating a vector we increment every time a key is found + std::vector temp_names; + std::vector temp_ids; + + for (int64 id = 0; id < records.size(); id++) { + const auto &sparse_binary = records[id].getSparseBinary(); + auto it = sparse_binary.find(feature_id); + // Find all instances of key in DataRecord + if (it != sparse_binary.end()) { + // insert to temp_names all the values in the dictionary value + temp_names.insert(temp_names.end(), it->second.begin(), it->second.end()); + temp_ids.insert(temp_ids.end(), it->second.size(), id); + } + } + + // The total_size will be the that of the saved vector + const int total_size = static_cast(temp_names.size()); + TensorShape shape = {total_size}; + Tensor* ids = nullptr; + Tensor* keys = nullptr; + Tensor* names = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &names)); + + auto keys_flat = keys->flat(); + auto names_flat = names->flat(); + auto ids_flat = ids->flat(); + + // The feature id value will always be the same + std::fill(keys_flat.data(), keys_flat.data() + total_size, feature_id); + std::copy(temp_names.begin(), temp_names.end(), names_flat.data()); + std::copy(temp_ids.begin(), temp_ids.end(), ids_flat.data()); + } catch (const std::exception &e) { + 
context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetSparseContinuousAsTensor") +.Input("data_record_handle: resource") +.Attr("feature_id: int") +.Output("ids: int64") +.Output("keys: int64") +.Output("names: string") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns tensors corresponding to the ids, keys, names and values of a particular +feature_id. +Input + data_record_handle: Resource handle to DataRecord +Attr + feature_id: Id representing the feature whose values will be extracted. +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) + keys: DataRecord keys (int64) + names: DataRecord values(string) + values: DataRecord values(float) +)doc"); +class GetSparseContinuousAsTensor : public OpKernel { + private: + int64 feature_id; + + public: + explicit GetSparseContinuousAsTensor(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + } + + void Compute(OpKernelContext* context) override { + try { + // We need two passes to the data: + // 1 to compute the output size of the tensor + // 2 to copy the values to the tensor + auto handle = getHandle(context, 0); + const auto &records = handle->records; + + // Creating a vector we increment every time a key is found + std::vector temp_names; + std::vector temp_values; + std::vector temp_ids; + + for (int64 id = 0; id < records.size(); id++) { + const auto &sparse_continuous = records[id].getSparseContinuous(); + auto it = sparse_continuous.find(feature_id); + // Find all instances of key in DataRecord + if (it != sparse_continuous.end()) { + // insert to temp_names all the values in the dictionary value + auto value_map = it->second; + for (auto& elem : value_map) { + temp_names.push_back(elem.first); + temp_values.push_back(elem.second); + temp_ids.push_back(id); + } + } + } + + // The total_size will be the that of the saved vector + const int total_size = static_cast(temp_names.size()); + TensorShape shape = {total_size}; + Tensor* ids = nullptr; + Tensor* keys = nullptr; + Tensor* names = nullptr; + Tensor* values = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &names)); + OP_REQUIRES_OK(context, context->allocate_output(3, shape, &values)); + + auto keys_flat = keys->flat(); + auto names_flat = names->flat(); + auto ids_flat = ids->flat(); + auto values_flat = values->flat(); + + // The feature id value will always be the same + std::fill(keys_flat.data(), keys_flat.data() + total_size, feature_id); + std::copy(temp_names.begin(), temp_names.end(), names_flat.data()); + std::copy(temp_ids.begin(), temp_ids.end(), ids_flat.data()); + std::copy(temp_values.begin(), temp_values.end(), values_flat.data()); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +// Helper function to add ids, keys and values to common vector +inline void addIdsKeysValuesToVectors( + const int64 id, + const int64 key, + const double value, + std::vector& ids, + std::vector& keys, + std::vector& values) { + ids.push_back(id); + keys.push_back(key); + values.push_back(value); +} + +struct KeepFeatures { + KeepFeatures() : vec(), set() {} + template + KeepFeatures(const std::vector 
&keep_features, + const ContainerType *const container) { + vec.reserve(keep_features.size()); +#ifdef USE_DENSE_HASH + set.resize(keep_features.size()); + set.set_empty_key(0); +#else + set.reserve(keep_features.size()); +#endif // USE_DENSE_HASH + set.max_load_factor(0.5); + for (const auto &elem : keep_features) { + if (container->find(elem) == container->end()) continue; + vec.push_back(elem); + set.insert(elem); + } + } + size_t size() const { + return vec.size(); + } + std::vector vec; + twml::Set set; +}; + +// Helper Function to Filter and Hash Feature for Binary Features +void filterAndHashFeature( + const twml::DataRecord::BinaryFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + addIdsKeysValuesToVectors(current_id, *iter, 1, ids, keys, values); + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem) == keep_features.set.end()) continue; + addIdsKeysValuesToVectors(current_id, elem, 1, ids, keys, values); + } + } +} + +// Helper Function to Filter and Hash Feature for Continuous Features +void filterAndHashFeature( + const twml::DataRecord::ContinuousFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + addIdsKeysValuesToVectors(current_id, iter->first, iter->second, ids, keys, values); + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + addIdsKeysValuesToVectors(current_id, elem.first, elem.second, ids, keys, values); + } + } +} + +// Helper Function to Filter and Hash Feature for Discrete Features +void filterAndHashFeature( + const twml::DataRecord::DiscreteFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + int64_t key = twml::mixDiscreteIdAndValue(iter->first, iter->second); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + int64_t key = twml::mixDiscreteIdAndValue(elem.first, elem.second); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } +} + +// Helper Function to Filter and Hash Feature for String Features +void filterAndHashFeature( + const twml::DataRecord::StringFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + int64_t key = twml::mixStringIdAndValue( + iter->first, + iter->second.size(), + reinterpret_cast(iter->second.c_str())); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } 
else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + int64_t key = twml::mixStringIdAndValue( + elem.first, + elem.second.size(), + reinterpret_cast(elem.second.c_str())); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } +} + +// Helper Function to Filter and Hash Feature for Sparse Binary Features +void filterAndHashFeature( + const twml::DataRecord::SparseBinaryFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + for (const auto &name : iter->second) { + int64_t key = twml::mixStringIdAndValue(iter->first, name.size(), + reinterpret_cast(name.c_str())); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + for (const auto &name : elem.second) { + int64_t key = twml::mixStringIdAndValue(elem.first, name.size(), + reinterpret_cast(name.c_str())); + addIdsKeysValuesToVectors(current_id, key, 1, ids, keys, values); + } + } + } +} + +// Helper Function to Filter and Hash Feature for Sparse Continuous Features +void filterAndHashFeature( + const twml::DataRecord::SparseContinuousFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + for (const auto &map : iter->second) { + int64_t key = twml::mixStringIdAndValue( + iter->first, + map.first.size(), + reinterpret_cast(map.first.c_str())); + addIdsKeysValuesToVectors(current_id, key, map.second, ids, keys, values); + } + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + for (const auto &map : elem.second) { + int64_t key = twml::mixStringIdAndValue( + elem.first, + map.first.size(), + reinterpret_cast(map.first.c_str())); + addIdsKeysValuesToVectors(current_id, key, map.second, ids, keys, values); + } + } + } +} + +// Helper Function to Filter and Hash Feature for Sparse Continuous Features +void filterAndHashFeatureCompat( + const twml::DataRecord::SparseContinuousFeatures& features, + const int64 current_id, + const KeepFeatures &keep_features, + std::vector& ids, + std::vector& keys, + std::vector& values) { + if (keep_features.size() < 2 * features.size()) { + for (const auto &f : keep_features.vec) { + const auto &iter = features.find(f); + if (iter == features.end()) continue; + for (const auto &map : iter->second) { + int64_t key = twml::featureId(map.first); + addIdsKeysValuesToVectors(current_id, key, map.second, ids, keys, values); + } + } + } else { + for (const auto &elem : features) { + if (keep_features.set.find(elem.first) == keep_features.set.end()) continue; + for (const auto &map : elem.second) { + int64_t key = twml::featureId(map.first); + addIdsKeysValuesToVectors(current_id, key, map.second, ids, keys, values); + } + } + } +} + +void copy_if_exists(std::vector& out, + const std::vector& in, + const twml::Map *const map) { + out.reserve(in.size()); + for (const 
auto &elem : in) { + if (map->find(elem) == map->end()) continue; + out.push_back(elem); + } +} + +void ComputeHashedFeaturesAsTensor(OpKernelContext* context, + const DataRecordResource *const handle, + const KeepFeatures &binary_keep_features, + const KeepFeatures &continuous_keep_features, + const KeepFeatures &discrete_keep_features, + const KeepFeatures &string_keep_features, + const KeepFeatures &sparse_binary_keep_features, + const KeepFeatures &sparse_continuous_keep_features, + bool sparse_continuous_compatibility) { + + const auto &records = handle->records; + uint64_t estimated_size = (binary_keep_features.size() + continuous_keep_features.size() + + discrete_keep_features.size() + string_keep_features.size() + + sparse_binary_keep_features.size() + + sparse_continuous_keep_features.size()); + // Construct temporary vectors for common features + std::vector common_ids, common_keys, temp_ids, temp_keys; + std::vector common_values, temp_values; + common_ids.reserve(estimated_size); + common_keys.reserve(estimated_size); + common_values.reserve(estimated_size); + + const auto &common_binary = handle->common.getBinary(); + const auto &common_continuous = handle->common.getContinuous(); + const auto &common_discrete = handle->common.getDiscrete(); + const auto &common_string = handle->common.getString(); + const auto &common_sparse_binary = handle->common.getSparseBinary(); + const auto &common_sparse_continuous = handle->common.getSparseContinuous(); + + filterAndHashFeature(common_binary, 0, binary_keep_features, + common_ids, common_keys, common_values); + filterAndHashFeature(common_continuous, 0, continuous_keep_features, + common_ids, common_keys, common_values); + filterAndHashFeature(common_discrete, 0, discrete_keep_features, + common_ids, common_keys, common_values); + filterAndHashFeature(common_string, 0, string_keep_features, + common_ids, common_keys, common_values); + filterAndHashFeature(common_sparse_binary, 0, sparse_binary_keep_features, + common_ids, common_keys, common_values); + if (sparse_continuous_compatibility) { + filterAndHashFeatureCompat(common_sparse_continuous, 0, sparse_continuous_keep_features, + common_ids, common_keys, common_values); + } else { + filterAndHashFeature(common_sparse_continuous, 0, sparse_continuous_keep_features, + common_ids, common_keys, common_values); + } + common_ids.clear(); + // Construct temporary vectors for all features + estimated_size = (estimated_size + common_keys.size()) * records.size(); + temp_ids.reserve(estimated_size); + temp_keys.reserve(estimated_size); + temp_values.reserve(estimated_size); + + for (int64 id = 0; id < records.size(); id++) { + temp_ids.insert(temp_ids.end(), common_keys.size(), id); + temp_keys.insert(temp_keys.end(), common_keys.begin(), common_keys.end()); + temp_values.insert(temp_values.end(), common_values.begin(), common_values.end()); + const auto &binary = records[id].getBinary(); + const auto &continuous = records[id].getContinuous(); + const auto &discrete = records[id].getDiscrete(); + const auto &str = records[id].getString(); + const auto &sparse_binary = records[id].getSparseBinary(); + const auto &sparse_continuous = records[id].getSparseContinuous(); + + filterAndHashFeature(binary, id, binary_keep_features, + temp_ids, temp_keys, temp_values); + filterAndHashFeature(continuous, id, continuous_keep_features, + temp_ids, temp_keys, temp_values); + filterAndHashFeature(discrete, id, discrete_keep_features, + temp_ids, temp_keys, temp_values); + filterAndHashFeature(str, id, 
string_keep_features, + temp_ids, temp_keys, temp_values); + filterAndHashFeature(sparse_binary, id, sparse_binary_keep_features, + temp_ids, temp_keys, temp_values); + if (sparse_continuous_compatibility) { + filterAndHashFeatureCompat(sparse_continuous, id, sparse_continuous_keep_features, + temp_ids, temp_keys, temp_values); + } else { + filterAndHashFeature(sparse_continuous, id, sparse_continuous_keep_features, + temp_ids, temp_keys, temp_values); + } + } + + // Copy the temporary vectors into the output Tensors + TensorShape shape = {static_cast(temp_ids.size())}; + Tensor* ids = nullptr; + Tensor* keys = nullptr; + Tensor* values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape, &values)); + auto ids_flat = ids->flat(); + auto keys_flat = keys->flat(); + auto values_flat = values->flat(); + std::copy(temp_ids.begin(), temp_ids.end(), ids_flat.data()); + std::copy(temp_keys.begin(), temp_keys.end(), keys_flat.data()); + std::copy(temp_values.begin(), temp_values.end(), values_flat.data()); +} + +REGISTER_OP("GetHashedFeaturesAsSparseTensor") +.Input("data_record_handle: resource") +.Attr("binary_keep_features: list(int)") +.Attr("continuous_keep_features: list(int)") +.Attr("discrete_keep_features: list(int)") +.Attr("string_keep_features: list(int)") +.Attr("sparse_binary_keep_features: list(int)") +.Attr("sparse_continuous_keep_features: list(int)") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); +}).Doc(R"doc( +A tensorflow OP for returning required features of different type as +a single sparse tensor. Hashing trick is applied. + +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records in the batch (int64) + keys: DataRecord keys (int64) + values: DataRecord values (float) +)doc"); + +class GetHashedFeaturesAsSparseTensor: public OpKernel { + public: + explicit GetHashedFeaturesAsSparseTensor(OpKernelConstruction* context): OpKernel(context) { + // Get the list of features to keep for each feature type + OP_REQUIRES_OK(context, context->GetAttr("binary_keep_features", &binary_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("continuous_keep_features", &continuous_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("discrete_keep_features", &discrete_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("string_keep_features", &string_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("sparse_binary_keep_features", &sparse_binary_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("sparse_continuous_keep_features", &sparse_continuous_keep_features_)); + } + + private: + std::vector binary_keep_features_, continuous_keep_features_, discrete_keep_features_; + std::vector string_keep_features_, sparse_binary_keep_features_, sparse_continuous_keep_features_; + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + // Create a new list of keep features based on the original keep_set. + // This is to ensure compatibility with existing behavior such as: + // - Ensure no new features are decoded in this op. + // - Ensure labels or weights dont get included here. + // TODO: Should we return features requested by user here even if they are labels / weights? 
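+      // Each KeepFeatures below intersects the user-requested ids with the
+      // decoder's keep_map, e.g. (hypothetical ids) requesting binary
+      // features {7, 9} when only 7 was decoded keeps just {7}. Both a
+      // vector and a set are retained so filterAndHashFeature can iterate
+      // whichever side is smaller and therefore cheaper.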
+ KeepFeatures binary_keep_features(binary_keep_features_, handle->keep_map); + KeepFeatures continuous_keep_features(continuous_keep_features_, handle->keep_map); + KeepFeatures discrete_keep_features(discrete_keep_features_, handle->keep_map); + KeepFeatures string_keep_features(string_keep_features_, handle->keep_map); + KeepFeatures sparse_binary_keep_features(sparse_binary_keep_features_, handle->keep_map); + KeepFeatures sparse_continuous_keep_features(sparse_continuous_keep_features_, handle->keep_map); + ComputeHashedFeaturesAsTensor(context, handle.get(), + binary_keep_features, + continuous_keep_features, + discrete_keep_features, + string_keep_features, + sparse_binary_keep_features, + sparse_continuous_keep_features, + false); + } catch(const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetHashedFeaturesAsSparseTensorV2") +.Input("data_record_handle: resource") +.Attr("binary_keep_features: list(int)") +.Attr("continuous_keep_features: list(int)") +.Attr("discrete_keep_features: list(int)") +.Attr("string_keep_features: list(int)") +.Attr("sparse_binary_keep_features: list(int)") +.Attr("sparse_continuous_keep_features: list(int)") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("decode_mode: int = 0") +.Output("ids: int64") +.Output("keys: int64") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); +}).Doc(R"doc( +A tensorflow OP for returning required features of different type as +a single sparse tensor. Hashing trick is applied. + +Input + data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records in the batch (int64) + keys: DataRecord keys (int64) + values: DataRecord values (float) +)doc"); + +class GetHashedFeaturesAsSparseTensorV2: public OpKernel { + public: + explicit GetHashedFeaturesAsSparseTensorV2(OpKernelConstruction* context): OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + std::vector binary_keep_features_, continuous_keep_features_, discrete_keep_features_; + std::vector string_keep_features_, sparse_binary_keep_features_, sparse_continuous_keep_features_; + + // Get the list of features to keep for each feature type + OP_REQUIRES_OK(context, context->GetAttr("binary_keep_features", &binary_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("continuous_keep_features", &continuous_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("discrete_keep_features", &discrete_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("string_keep_features", &string_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("sparse_binary_keep_features", &sparse_binary_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("sparse_continuous_keep_features", &sparse_continuous_keep_features_)); + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + OP_REQUIRES_OK(context, context->GetAttr("decode_mode", &m_decode_mode)); + + twml::Map keep_map; +#ifdef USE_DENSE_HASH + keep_map.set_empty_key(0); +#endif // USE_DENSE_HASH + for (uint64_t i = 0; i < keep_features.size(); i++) { + keep_map[keep_features[i]] = keep_codes[i]; + } + + + binary_keep_features = KeepFeatures(binary_keep_features_, &keep_map); + continuous_keep_features = KeepFeatures(continuous_keep_features_, &keep_map); + discrete_keep_features = 
KeepFeatures(discrete_keep_features_, &keep_map);
+    string_keep_features = KeepFeatures(string_keep_features_, &keep_map);
+    sparse_binary_keep_features = KeepFeatures(sparse_binary_keep_features_, &keep_map);
+    sparse_continuous_keep_features = KeepFeatures(sparse_continuous_keep_features_, &keep_map);
+  }
+
+ private:
+  KeepFeatures binary_keep_features, continuous_keep_features, discrete_keep_features;
+  KeepFeatures string_keep_features, sparse_binary_keep_features, sparse_continuous_keep_features;
+  int64 m_decode_mode;
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      auto handle = getHandle<DataRecordResource>(context, 0);
+      // The keep-features lists were already intersected with the keep_map in
+      // the constructor. This ensures compatibility with existing behavior:
+      // - No new features are decoded in this op.
+      // - Labels and weights don't get included here.
+      // TODO: Should we return features requested by the user here even if they are labels / weights?
+      ComputeHashedFeaturesAsTensor(context, handle.get(),
+                                    binary_keep_features,
+                                    continuous_keep_features,
+                                    discrete_keep_features,
+                                    string_keep_features,
+                                    sparse_binary_keep_features,
+                                    sparse_continuous_keep_features,
+                                    m_decode_mode == 0);
+    } catch(const std::exception &e) {
+      context->CtxFailureWithWarning(errors::InvalidArgument(e.what()));
+    }
+  }
+};
+
+#define REGISTER_DECODE_DATA_RECORD(InputType)  \
+  REGISTER_KERNEL_BUILDER(                      \
+      Name("DecodeDataRecord")                  \
+      .Device(DEVICE_CPU)                       \
+      .TypeConstraint<InputType>("InputType"),  \
+      DecodeDataRecord<InputType>);
+
+REGISTER_DECODE_DATA_RECORD(uint8)
+REGISTER_DECODE_DATA_RECORD(string)
+
+#define REGISTER_GETTER(FIELD)       \
+  REGISTER_KERNEL_BUILDER(           \
+      Name("Get" #FIELD "Features")  \
+      .Device(DEVICE_CPU),           \
+      Get##FIELD##Features);
+
+#define REGISTER_GETTER_FROM_DR(FIELD)     \
+  REGISTER_KERNEL_BUILDER(                 \
+      Name("Get" #FIELD "FromDataRecord")  \
+      .Device(DEVICE_CPU),                 \
+      Get##FIELD##FromDataRecord);
+
+#define REGISTER_GETTER_AS_TENSOR(FIELD)  \
+  REGISTER_KERNEL_BUILDER(                \
+      Name("Get" #FIELD "AsTensor")       \
+      .Device(DEVICE_CPU),                \
+      Get##FIELD##AsTensor);
+
+#define REGISTER_GETTER_GROUP_AS_TENSOR(FIELD)  \
+  REGISTER_KERNEL_BUILDER(                      \
+      Name("Get" #FIELD "GroupAsTensor")        \
+      .Device(DEVICE_CPU),                      \
+      Get##FIELD##GroupAsTensor);
+
+REGISTER_GETTER(Binary)
+REGISTER_GETTER(Continuous)
+REGISTER_GETTER(Discrete)
+REGISTER_GETTER(String)
+REGISTER_GETTER(SparseBinary)
+REGISTER_GETTER(SparseContinuous)
+REGISTER_GETTER_FROM_DR(BatchSize)
+REGISTER_GETTER_FROM_DR(Labels)
+REGISTER_GETTER_FROM_DR(Weights)
+REGISTER_GETTER_AS_TENSOR(Binary)
+REGISTER_GETTER_AS_TENSOR(Continuous)
+REGISTER_GETTER_AS_TENSOR(Discrete)
+REGISTER_GETTER_AS_TENSOR(String)
+REGISTER_GETTER_AS_TENSOR(SparseBinary)
+REGISTER_GETTER_AS_TENSOR(SparseContinuous)
+REGISTER_GETTER_GROUP_AS_TENSOR(Binary)
+REGISTER_GETTER_GROUP_AS_TENSOR(Continuous)
+REGISTER_GETTER_GROUP_AS_TENSOR(Discrete)
+REGISTER_GETTER_GROUP_AS_TENSOR(String)
+REGISTER_KERNEL_BUILDER(
+  Name("GetHashedFeaturesAsSparseTensor")
+  .Device(DEVICE_CPU),
+  GetHashedFeaturesAsSparseTensor);
+REGISTER_KERNEL_BUILDER(
+  Name("GetHashedFeaturesAsSparseTensorV2")
+  .Device(DEVICE_CPU),
+  GetHashedFeaturesAsSparseTensorV2);
diff --git a/twml/libtwml/src/ops/data_record_tensor_writer.cpp b/twml/libtwml/src/ops/data_record_tensor_writer.cpp
new file mode 100644
index 000000000..9368c870e
--- /dev/null
+++ b/twml/libtwml/src/ops/data_record_tensor_writer.cpp
@@ -0,0 +1,81 @@
+#include "tensorflow/core/framework/op.h"
"tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + +REGISTER_OP("DataRecordTensorWriter") +.Attr("T: list({string, int32, int64, float, double, bool})") +.Input("keys: int64") +.Input("values: T") +.Output("result: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP that packages keys and dense tensors into a DataRecord. + +values: list of tensors +keys: feature ids from the original DataRecord (int64) + +Outputs + bytes: output DataRecord serialized using Thrift into a uint8 tensor. +)doc"); + +class DataRecordTensorWriter : public OpKernel { + public: + explicit DataRecordTensorWriter(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& keys = context->input(0); + + try { + // set keys as twml::Tensor + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + + // check sizes + uint64_t num_keys = in_keys_.getNumElements(); + uint64_t num_values = context->num_inputs() - 1; + + OP_REQUIRES(context, num_keys == num_values, + errors::InvalidArgument("Number of dense keys and dense tensors do not match")); + + // populate DataRecord object + const int64_t *keys = in_keys_.getData(); + twml::DataRecord record = twml::DataRecord(); + + for (int i = 1; i < context->num_inputs(); i++) { + const twml::RawTensor& value = TFTensor_to_twml_raw_tensor(context->input(i)); + record.addRawTensor(keys[i-1], value); + } + + // determine the length of the encoded result (no memory is copied) + twml::ThriftWriter thrift_dry_writer = twml::ThriftWriter(nullptr, 0, true); + twml::DataRecordWriter record_dry_writer = twml::DataRecordWriter(thrift_dry_writer); + record_dry_writer.write(record); + int len = thrift_dry_writer.getBytesWritten(); + TensorShape result_shape = {1, len}; + + // allocate output tensor + Tensor* result = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, result_shape, &result)); + twml::Tensor out_result = TFTensor_to_twml_tensor(*result); + + // write to output tensor + uint8_t *buffer = out_result.getData(); + twml::ThriftWriter thrift_writer = twml::ThriftWriter(buffer, len, false); + twml::DataRecordWriter record_writer = twml::DataRecordWriter(thrift_writer); + record_writer.write(record); + } catch(const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("DataRecordTensorWriter").Device(DEVICE_CPU), + DataRecordTensorWriter); diff --git a/twml/libtwml/src/ops/discretizer.cpp b/twml/libtwml/src/ops/discretizer.cpp new file mode 100644 index 000000000..10d1b3c78 --- /dev/null +++ b/twml/libtwml/src/ops/discretizer.cpp @@ -0,0 +1,293 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + + +void ComputeDiscretizers(OpKernelContext* context, const bool return_bin_indices = false) { + const Tensor& keys = context->input(0); + const Tensor& vals = context->input(1); + const Tensor& bin_ids = context->input(2); + const Tensor& bin_vals = context->input(3); + const Tensor& feature_offsets = context->input(4); + + Tensor* new_keys = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, keys.shape(), + &new_keys)); + Tensor* new_vals = 
nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, keys.shape(), + &new_vals)); + + try { + twml::Tensor out_keys_ = TFTensor_to_twml_tensor(*new_keys); + twml::Tensor out_vals_ = TFTensor_to_twml_tensor(*new_vals); + + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + const twml::Tensor in_vals_ = TFTensor_to_twml_tensor(vals); + const twml::Tensor bin_ids_ = TFTensor_to_twml_tensor(bin_ids); + const twml::Tensor bin_vals_ = TFTensor_to_twml_tensor(bin_vals); + const twml::Tensor feature_offsets_ = TFTensor_to_twml_tensor(feature_offsets); + twml::mdlInfer(out_keys_, out_vals_, + in_keys_, in_vals_, + bin_ids_, bin_vals_, + feature_offsets_, + return_bin_indices); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } +} + +REGISTER_OP("MDL") +.Attr("T: {float, double}") +.Input("keys: int64") +.Input("vals: T") +.Input("bin_ids: int64") +.Input("bin_vals: T") +.Input("feature_offsets: int64") +.Output("new_keys: int64") +.Output("new_vals: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // TODO: check sizes + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); +}).Doc(R"doc( + +This operation discretizes a tensor containing continuous features. + +Input + keys: A tensor containing feature ids. + vals: A tensor containing values at corresponding feature ids. + bin_ids: A tensor containing the discretized feature id for a given bin. + bin_vals: A tensor containing the bin boundaries for value at a given feature id. + feature_offsets: Specifies the starting location of bins for a given feature id. + +Expected Sizes: + keys, vals: [N]. + bin_ids, bin_vals: [sum_{n=1}^{n=num_classes} num_bins(n)] + + where + - N is the number of sparse features in the current batch. + - [0, num_classes) represents the range each feature id can take. + - num_bins(n) is the number of bins for a given feature id. + - If num_bins is fixed, then xs, ys are of size [num_classes * num_bins]. + +Expected Types: + keys, bin_ids: int64. + vals: float or double. + bin_vals: same as vals. + +Before using MDL, you should use a hashmap to get the intersection of +input `keys` with the features that MDL knows about: +:: + keys, vals # keys can be in range [0, 1 << 63) + mdl_keys = hashmap.find(keys) # mdl_keys are now in range [0, num_classes_from_calibration) + mdl_keys = where (mdl_keys != -1) # Ignore keys not found + + +Inside MDL, the following is happening: +:: + start = offsets[key[i]] + end = offsets[key[i] + 1] + idx = binary_search for val[i] in [bin_vals[start], bin_vals[end]] + + result_keys[i] = bin_ids[idx] + val[i] = 1 # binary feature value + +Outputs + new_keys: The discretized feature ids with same shape and size as keys. + new_vals: The discretized values with the same shape and size as vals. 
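+
+For illustration (hypothetical values), with feature_offsets = [0, 2, 4], feature id 0
+owns bins bin_ids[0:2] and feature id 1 owns bins bin_ids[2:4]:
+::
+  keys = [0, 1], vals = [0.25, 7.0]
+  new_keys[0] is one of bin_ids[0:2]; new_keys[1] is one of bin_ids[2:4]
+  new_vals = [1.0, 1.0]  # every output value is a binary indicator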
+ +)doc"); + + +template +class MDL : public OpKernel { + public: + explicit MDL(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + ComputeDiscretizers(context); + } +}; + +REGISTER_OP("PercentileDiscretizer") +.Attr("T: {float, double}") +.Input("keys: int64") +.Input("vals: T") +.Input("bin_ids: int64") +.Input("bin_vals: T") +.Input("feature_offsets: int64") +.Output("new_keys: int64") +.Output("new_vals: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // TODO: check sizes + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); +}).Doc(R"doc( + +This operation discretizes a tensor containing continuous features. + +Input + keys: A tensor containing feature ids. + vals: A tensor containing values at corresponding feature ids. + bin_ids: A tensor containing the discretized feature id for a given bin. + bin_vals: A tensor containing the bin boundaries for value at a given feature id. + feature_offsets: Specifies the starting location of bins for a given feature id. + +Expected Sizes: + keys, vals: [N]. + bin_ids, bin_vals: [sum_{n=1}^{n=num_classes} num_bins(n)] + + where + - N is the number of sparse features in the current batch. + - [0, num_classes) represents the range each feature id can take. + - num_bins(n) is the number of bins for a given feature id. + - If num_bins is fixed, then xs, ys are of size [num_classes * num_bins]. + +Expected Types: + keys, bin_ids: int64. + vals: float or double. + bin_vals: same as vals. + +Before using PercentileDiscretizer, you should use a hashmap to get the intersection of +input `keys` with the features that PercentileDiscretizer knows about: +:: + keys, vals # keys can be in range [0, 1 << 63) + percentile_discretizer_keys = hashmap.find(keys) # percentile_discretizer_keys are now in range [0, num_classes_from_calibration) + percentile_discretizer_keys = where (percentile_discretizer_keys != -1) # Ignore keys not found + + +Inside PercentileDiscretizer, the following is happening: +:: + start = offsets[key[i]] + end = offsets[key[i] + 1] + idx = binary_search for val[i] in [bin_vals[start], bin_vals[end]] + + result_keys[i] = bin_ids[idx] + val[i] = 1 # binary feature value + +Outputs + new_keys: The discretized feature ids with same shape and size as keys. + new_vals: The discretized values with the same shape and size as vals. + +)doc"); + +template +class PercentileDiscretizer : public OpKernel { + public: + explicit PercentileDiscretizer(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + ComputeDiscretizers(context); + } +}; + + +REGISTER_OP("PercentileDiscretizerBinIndices") +.Attr("T: {float, double}") +.Input("keys: int64") +.Input("vals: T") +.Input("bin_ids: int64") +.Input("bin_vals: T") +.Input("feature_offsets: int64") +.Output("new_keys: int64") +.Output("new_vals: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // TODO: check sizes + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); +}).Doc(R"doc( + +This operation discretizes a tensor containing continuous features. +If the feature id and bin id of the discretized value is the same on multiple runs, they +will always be assigned to the same output key and value, regardless of the bin_id assigned during +calibration. + +Input + keys: A tensor containing feature ids. + vals: A tensor containing values at corresponding feature ids. 
+ bin_ids: A tensor containing the discretized feature id for a given bin. + bin_vals: A tensor containing the bin boundaries for value at a given feature id. + feature_offsets: Specifies the starting location of bins for a given feature id. + +Expected Sizes: + keys, vals: [N]. + bin_ids, bin_vals: [sum_{n=1}^{n=num_classes} num_bins(n)] + + where + - N is the number of sparse features in the current batch. + - [0, num_classes) represents the range each feature id can take. + - num_bins(n) is the number of bins for a given feature id. + - If num_bins is fixed, then xs, ys are of size [num_classes * num_bins]. + +Expected Types: + keys, bin_ids: int64. + vals: float or double. + bin_vals: same as vals. + +Before using PercentileDiscretizerBinIndices, you should use a hashmap to get the intersection of +input `keys` with the features that PercentileDiscretizerBinIndices knows about: +:: + keys, vals # keys can be in range [0, 1 << 63) + percentile_discretizer_keys = hashmap.find(keys) # percentile_discretizer_keys are now in range [0, num_classes_from_calibration) + percentile_discretizer_keys = where (percentile_discretizer_keys != -1) # Ignore keys not found + + +Inside PercentileDiscretizerBinIndices, the following is happening: +:: + start = offsets[key[i]] + end = offsets[key[i] + 1] + idx = binary_search for val[i] in [bin_vals[start], bin_vals[end]] + + result_keys[i] = bin_ids[idx] + val[i] = 1 # binary feature value + +Outputs + new_keys: The discretized feature ids with same shape and size as keys. + new_vals: The discretized values with the same shape and size as vals. + +)doc"); + +template +class PercentileDiscretizerBinIndices : public OpKernel { + public: + explicit PercentileDiscretizerBinIndices(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + ComputeDiscretizers(context, true); + } +}; + + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("PercentileDiscretizerBinIndices") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + PercentileDiscretizerBinIndices); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("PercentileDiscretizer") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + PercentileDiscretizer); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("MDL") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + MDL); \ + +REGISTER(float); +REGISTER(double); diff --git a/twml/libtwml/src/ops/feature_extractor.cpp b/twml/libtwml/src/ops/feature_extractor.cpp new file mode 100644 index 000000000..9e0910bae --- /dev/null +++ b/twml/libtwml/src/ops/feature_extractor.cpp @@ -0,0 +1,134 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" +#include +#include + +REGISTER_OP("FeatureExtractor") +.Attr("T: {float, double} = DT_FLOAT") +.Input("mask_in: bool") +.Input("ids_in: int64") +.Input("keys_in: int64") +.Input("values_in: T") +.Input("codes_in: int64") +.Input("types_in: int8") +.Output("ids_out: int64") +.Output("keys_out: int64") +.Output("values_out: T") +.Output("codes_out: int64") +.Output("types_out: int8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP that extracts the desired indices of a Tensor based on a mask + +Input + mask_in: boolean Tensor that determines which are the indices to be kept (bool) + ids_in: input indices Tensor (int64) + keys_in: input keys Tensor (int64) + 
values_in: input values Tensor (float/double) + codes_in: input codes Tensor (int64) + types_in: input types Tensor(int8) + +Outputs + ids_out: output indices Tensor (int64) + keys_out: output keys Tensor (int64) + values_out: output values Tensor (float/double) + codes_out: output codes Tensor (int64) + types_out: output types Tensor(int8) + +)doc"); +template +class FeatureExtractor : public OpKernel { + public: + explicit FeatureExtractor(OpKernelConstruction* context) + : OpKernel(context) {} + + template + bool allequal(const A &t, const U &u) { + return t == u; + } + + template + bool allequal(const A &t, const U &u, Others const &... args) { + return (t == u) && allequal(u, args...); + } + + void Compute(OpKernelContext* context) override { + // Get input tensors + const Tensor& input_mask = context->input(0); + const Tensor& input_ids = context->input(1); + const Tensor& input_keys = context->input(2); + const Tensor& input_values = context->input(3); + const Tensor& input_codes = context->input(4); + const Tensor& input_types = context->input(5); + + auto mask = input_mask.flat(); + auto ids = input_ids.flat(); + auto keys = input_keys.flat(); + auto codes = input_codes.flat(); + auto values = input_values.flat(); + auto types = input_types.flat(); + + // Verify that all Tensors have the same size. + OP_REQUIRES(context, allequal(mask.size(), ids.size(), keys.size(), codes.size(), values.size(), types.size()), + errors::InvalidArgument("all input vectors must be the same size.")); + + // Get the size of the output vectors by counting the numbers of trues. + int total_size = 0; + for (int i = 0; i < mask.size(); i++) { + if (mask(i)) + total_size += 1; + } + + // Shape is the number of Trues in the mask Eigen::Tensor + TensorShape shape_out = {total_size}; + + // Create the output tensors + Tensor* output_codes = nullptr; + Tensor* output_ids = nullptr; + Tensor* output_values = nullptr; + Tensor* output_types = nullptr; + Tensor* output_keys = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, shape_out, &output_ids)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape_out, &output_keys)); + OP_REQUIRES_OK(context, context->allocate_output(2, shape_out, &output_values)); + OP_REQUIRES_OK(context, context->allocate_output(3, shape_out, &output_codes)); + OP_REQUIRES_OK(context, context->allocate_output(4, shape_out, &output_types)); + + auto output_ids_ = output_ids->flat(); + auto output_keys_ = output_keys->flat(); + auto output_codes_ = output_codes->flat(); + auto output_values_ = output_values->flat(); + auto output_types_ = output_types->flat(); + + // Iterate through the mask and set values to output Eigen::Tensors + int j = 0; + for (int i = 0; i < mask.size(); i++) { + if (mask(i)) { + output_ids_(j) = ids(i); + output_keys_(j) = keys(i); + output_values_(j) = values(i); + output_codes_(j) = codes(i); + output_types_(j) = types(i); + ++j; + } + } + } +}; + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FeatureExtractor") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + FeatureExtractor); \ + +REGISTER(float); +REGISTER(double); diff --git a/twml/libtwml/src/ops/feature_id.cpp b/twml/libtwml/src/ops/feature_id.cpp new file mode 100644 index 000000000..150b5614c --- /dev/null +++ b/twml/libtwml/src/ops/feature_id.cpp @@ -0,0 +1,58 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + 
+using namespace tensorflow;
+
+REGISTER_OP("FeatureId")
+.Attr("feature_names: list(string)")
+.Output("output: int64")
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+  return Status::OK();
+}).Doc(R"doc(
+
+A tensorflow OP that hashes a list of strings into int64. This is used for feature name hashing.
+
+Attr
+  feature_names: a list of string feature names (list(string)).
+
+Outputs
+  output: hashes corresponding to the string feature names (int64).
+)doc");
+
+
+class FeatureId : public OpKernel {
+ private:
+  std::vector<string> input_vector;
+
+ public:
+  explicit FeatureId(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("feature_names", &input_vector));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Get size of the input_vector and create TensorShape shape
+    const int total_size = static_cast<int>(input_vector.size());
+    TensorShape shape = {total_size};
+
+    // Create an output tensor
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, shape,
+                                                     &output_tensor));
+    auto output_flat = output_tensor->flat<int64>();
+
+    // Hash each input feature name into an int64 id
+    for (int i = 0; i < total_size; i++) {
+      output_flat(i) = twml::featureId(input_vector[i]);
+    }
+  }
+};
+
+
+REGISTER_KERNEL_BUILDER(
+  Name("FeatureId")
+  .Device(DEVICE_CPU),
+  FeatureId);
diff --git a/twml/libtwml/src/ops/feature_mask.cpp b/twml/libtwml/src/ops/feature_mask.cpp
new file mode 100644
index 000000000..fc1498413
--- /dev/null
+++ b/twml/libtwml/src/ops/feature_mask.cpp
@@ -0,0 +1,83 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <twml.hpp>
+#include "tensorflow_utils.h"
+#include <set>
+#include <vector>
+#include <string>
+
+REGISTER_OP("FeatureMask")
+.Attr("T: {int64, int8}")
+.Input("keep: T")
+.Attr("list_keep: list(int)")
+.Output("mask: bool")
+
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+  return Status::OK();
+}).Doc(R"doc(
+
+A tensorflow OP that creates a mask of the indices that should be kept.
+
+Attribute
+  list_keep: list of values which should be kept (list(int))
+
+Input
+  keep: Tensor for which we will apply the mask (int64, int8)
+
+Outputs
+  mask: boolean Tensor. (bool)
+
+)doc");
+template<typename T>
+class FeatureMask : public OpKernel {
+ private:
+  std::set<int64> feature_set_keep;
+
+ public:
+  explicit FeatureMask(OpKernelConstruction* context)
+    : OpKernel(context) {
+    std::vector<int64> feature_list_keep;
+    OP_REQUIRES_OK(context, context->GetAttr("list_keep", &feature_list_keep));
+    // Copy the contents of feature_list_keep into a set, since tensorflow does not
+    // allow a list attribute to be read directly into a set
+    feature_set_keep = std::set<int64>(feature_list_keep.begin(), feature_list_keep.end());
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Get the input tensor whose entries will be checked against the keep set
+    const Tensor& input = context->input(0);
+
+    auto keep = input.flat<T>();
+
+    // Create an output tensor
+    Tensor* output_mask = nullptr;
+
+    // Output shape is determined and now we can copy the contents of the vector to the output Tensor.
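+    // e.g. (hypothetical values) keep = [3, 7, 9] with list_keep = [7] -> mask = [false, true, false]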
+    const int total_size_out = static_cast<int>(keep.size());
+
+    TensorShape shape_out = {total_size_out};
+
+    OP_REQUIRES_OK(context, context->allocate_output(0, shape_out, &output_mask));
+
+    auto output_mask_ = output_mask->flat<bool>();
+
+    // Check if value is in set, output is boolean
+    for (int j = 0; j < keep.size(); j++) {
+      output_mask_(j) = (feature_set_keep.count(keep(j)));
+    }
+  }
+};
+
+
+#define REGISTER(Type)                      \
+                                            \
+  REGISTER_KERNEL_BUILDER(                  \
+      Name("FeatureMask")                   \
+      .Device(DEVICE_CPU)                   \
+      .TypeConstraint<Type>("T"),           \
+      FeatureMask<Type>);                   \
+
+REGISTER(int64);
+REGISTER(int8);
diff --git a/twml/libtwml/src/ops/fixed_length_tensor.cpp b/twml/libtwml/src/ops/fixed_length_tensor.cpp
new file mode 100644
index 000000000..876367ad3
--- /dev/null
+++ b/twml/libtwml/src/ops/fixed_length_tensor.cpp
@@ -0,0 +1,190 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <twml.hpp>
+#include "tensorflow_utils.h"
+#include "resource_utils.h"
+
+#include <string>
+using std::string;
+
+template<typename IndexType, typename ValueType, bool calc_batch_size>
+void ComputeFixedLengthTensor(OpKernelContext *context, int64 max_length_) {
+  try {
+    const Tensor& segment_ids = context->input(0);
+    const Tensor& values = context->input(1);
+    const Tensor& pad_value = context->input(2);
+
+    auto indices_flat = segment_ids.flat<IndexType>();
+    auto values_flat = values.flat<ValueType>();
+
+    auto pad_value_scalar = pad_value.scalar<ValueType>()();
+
+    // Get maximum length from batch if user hasn't specified it.
+    int64 max_length = max_length_;
+    if (max_length < 0 && indices_flat.size() > 0) {
+      int64 current_id = indices_flat(0);
+      int64 current_length = 1;
+
+      for (int64 i = 1; i < indices_flat.size(); i++) {
+        if (current_id == indices_flat(i)) {
+          current_length++;
+        } else {
+          current_id = indices_flat(i);
+          max_length = std::max(max_length, current_length);
+          current_length = 1;
+        }
+      }
+      // This is needed if the last segment is the longest sequence.
+      max_length = std::max(max_length, current_length);
+    }
+
+    int64 batch_size = 0;
+    if (calc_batch_size) {
+      if (indices_flat.size() > 0) {
+        // The last value of segment_ids will be batch_size - 1.
+        batch_size = 1 + indices_flat(indices_flat.size() - 1);
+      } else {
+        batch_size = 0;
+      }
+    } else {
+      const Tensor& batch_size_tensor = context->input(3);
+      batch_size = batch_size_tensor.flat<int64>()(0);
+    }
+
+    TensorShape output_shape = {batch_size, max_length};
+    Tensor* fixed_length = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &fixed_length));
+
+    auto fixed_length_flat = fixed_length->flat<ValueType>();
+
+    int64 n = 0;
+    int64 offset = 0;
+    for (int64 i = 0; i < batch_size; i++) {
+      for (int64 j = 0; j < max_length; j++) {
+        if (n < indices_flat.size() && indices_flat(n) == i) {
+          // Copy from variable length tensor.
+          fixed_length_flat(offset + j) = values_flat(n);
+          n++;
+        } else {
+          // Pad to fixed length.
+          fixed_length_flat(offset + j) = pad_value_scalar;
+        }
+      }
+      // Corner case: truncate to max_length if user specified max_length < current length.
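+      // e.g. (hypothetical values) max_length = 2 with segment [4, 5, 6]: only [4, 5]
+      // is written, and the loop below advances n past the segment's remaining entries.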
+ while (n < indices_flat.size() && i == indices_flat(n)) n++; + + // Update output pointer + offset += max_length; + } + } catch (const std::exception &err) { + context->CtxFailureWithWarning(errors::InvalidArgument(err.what())); + } +} + +REGISTER_OP("FixedLengthTensor") +.Attr("IndexType: {int64, int32}") +.Attr("ValueType: {int64, int32, string}") +.Attr("max_length: int") +.Input("segment_ids: IndexType") +.Input("values: ValueType") +.Input("pad_value: ValueType") +.Output("fixed_length: ValueType") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP to convert variable length segments into fixed length tensor. + +Attr + max_length: The size of the inner most (i.e. last) dimension. + +Input + segment_ids: 1D input tensor containing the sorted segment_ids. + values: 1D input tensor containing the values. + pad_value: The value used for padding the fixed length tensor. + +Outputs + fixed_length: A fixed length tensor of size [batch_size, max_length]. +)doc"); + +template +class FixedLengthTensor: public OpKernel { + public: + explicit FixedLengthTensor(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("max_length", &max_length_)); + } + + private: + int64 max_length_; + + void Compute(OpKernelContext *context) override { + ComputeFixedLengthTensor(context, max_length_); + } +}; + +REGISTER_OP("FixedLengthTensorV2") +.Attr("IndexType: {int64, int32}") +.Attr("ValueType: {int64, int32, string}") +.Attr("max_length: int") +.Input("segment_ids: IndexType") +.Input("values: ValueType") +.Input("pad_value: ValueType") +.Input("batch_size: int64") +.Output("fixed_length: ValueType") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( + +A tensorflow OP to convert variable length segments into fixed length tensor. + +Attr + max_length: The size of the inner most (i.e. last) dimension. + +Input + segment_ids: 1D input tensor containing the sorted segment_ids. + values: 1D input tensor containing the values. + pad_value: The value used for padding the fixed length tensor. + batch_size: The batch size to use. + +Outputs + fixed_length: A fixed length tensor of size [batch_size, max_length]. 
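+
+For illustration (hypothetical values):
+::
+  segment_ids = [0, 0, 1], values = [4, 5, 6], pad_value = 0
+  batch_size = 2, max_length = 2
+  fixed_length = [[4, 5], [6, 0]]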
+)doc"); + +template +class FixedLengthTensorV2: public OpKernel { + public: + explicit FixedLengthTensorV2(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("max_length", &max_length_)); + } + + private: + int64 max_length_; + + void Compute(OpKernelContext *context) override { + ComputeFixedLengthTensor(context, max_length_); + } +}; + +#define REGISTER_SPARSE_TO_FIXED_LENGTH(IndexType, ValueType) \ + REGISTER_KERNEL_BUILDER( \ + Name("FixedLengthTensor") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("IndexType") \ + .TypeConstraint("ValueType"), \ + FixedLengthTensor); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FixedLengthTensorV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("IndexType") \ + .TypeConstraint("ValueType"), \ + FixedLengthTensorV2); \ + +REGISTER_SPARSE_TO_FIXED_LENGTH(int64, int64) +REGISTER_SPARSE_TO_FIXED_LENGTH(int64, int32) +REGISTER_SPARSE_TO_FIXED_LENGTH(int64, string) +REGISTER_SPARSE_TO_FIXED_LENGTH(int32, int64) +REGISTER_SPARSE_TO_FIXED_LENGTH(int32, int32) +REGISTER_SPARSE_TO_FIXED_LENGTH(int32, string) diff --git a/twml/libtwml/src/ops/hashed_data_record.cpp b/twml/libtwml/src/ops/hashed_data_record.cpp new file mode 100644 index 000000000..ba094c3d9 --- /dev/null +++ b/twml/libtwml/src/ops/hashed_data_record.cpp @@ -0,0 +1,520 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" +#include "resource_utils.h" + +#include + +REGISTER_OP("DecodeAndHashDataRecord") +.Attr("InputType: {uint8, string}") +.Input("input_bytes: InputType") +.Attr("keep_features: list(int)") +.Attr("keep_codes: list(int)") +.Attr("label_features: list(int)") +.Attr("weight_features: list(int) = []") +.Attr("decode_mode: int = 0") +.Output("hashed_data_record_handle: resource") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that creates a handle for the hashed data record. + +Attr + keep_features: a list of int ids to keep. + keep_codes: their corresponding code. + label_features: list of feature ids representing the labels. + weight_features: list of feature ids representing the weights. Defaults to empty list. + decode_mode: integer, indicates which decoding method to use. Let a sparse continuous + have a feature_name and a dict of {name: value}. 0 indicates feature_ids are computed + as hash(name). 1 indicates feature_ids are computed as hash(feature_name, name) + shared_name: name used by the resource handle inside the resource manager. + container: name used by the container of the resources. + +Input + input_bytes: Input tensor containing the serialized batch of HashedDataRecords. + +Outputs + hashed_data_record_handle: A resource handle to batch of HashedDataRecords. 
+)doc"); + +template +class DecodeAndHashDataRecord : public OpKernel { + public: + explicit DecodeAndHashDataRecord(OpKernelConstruction* context) + : OpKernel(context) { + std::vector keep_features; + std::vector keep_codes; + + std::vector label_features; + std::vector weight_features; + + OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); + OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); + OP_REQUIRES_OK(context, context->GetAttr("label_features", &label_features)); + OP_REQUIRES_OK(context, context->GetAttr("weight_features", &weight_features)); + OP_REQUIRES_OK(context, context->GetAttr("decode_mode", &m_decode_mode)); + + OP_REQUIRES(context, keep_features.size() == keep_codes.size(), + errors::InvalidArgument("keep keys and values must have same size.")); + +#ifdef USE_DENSE_HASH + m_keep_map.set_empty_key(0); + m_labels_map.set_empty_key(0); + m_weights_map.set_empty_key(0); +#endif // USE_DENSE_HASH + + for (uint64_t i = 0; i < keep_features.size(); i++) { + m_keep_map[keep_features[i]] = keep_codes[i]; + } + + for (uint64_t i = 0; i < label_features.size(); i++) { + m_labels_map[label_features[i]] = i; + } + + for (uint64_t i = 0; i < weight_features.size(); i++) { + m_weights_map[weight_features[i]] = i; + } + } + + private: + twml::Map m_keep_map; + twml::Map m_labels_map; + twml::Map m_weights_map; + int64 m_decode_mode; + + void Compute(OpKernelContext* context) override { + try { + HashedDataRecordResource *resource = nullptr; + OP_REQUIRES_OK(context, makeResourceHandle(context, 0, &resource)); + + // Store the input bytes in the resource so it isnt freed before the resource. + // This is necessary because we are not copying the contents for tensors. + resource->input = context->input(0); + int batch_size = getBatchSize(resource->input); + int num_labels = static_cast(m_labels_map.size()); + int num_weights = static_cast(m_weights_map.size()); + + twml::HashedDataRecordReader reader; + reader.setKeepMap(&m_keep_map); + reader.setLabelsMap(&m_labels_map); + reader.setDecodeMode(m_decode_mode); + + // Do not set weight map if it is empty. This will take a faster path. + if (num_weights != 0) { + reader.setWeightsMap(&m_weights_map); + } + + resource->records.clear(); + resource->records.reserve(batch_size); + + int64 total_size = 0; + + for (int id = 0; id < batch_size; id++) { + const uint8_t *input_bytes = getInputBytes(resource->input, id); + reader.setBuffer(input_bytes); + resource->records.emplace_back(num_labels, num_weights); + resource->records[id].decode(reader); + total_size += static_cast(resource->records[id].totalSize()); + } + + resource->total_size = total_size; + resource->num_labels = num_labels; + resource->num_weights = num_weights; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetIdsFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("ids: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns unhashed ids from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + ids: ids specifies the index of the records[id] in the batch (int64) +)doc"); + +// This Kernel is used for both training and serving once the resource is created. 
+class GetIdsFromHashedDataRecord : public OpKernel { + public: + explicit GetIdsFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + const int64 common_size = static_cast(common.totalSize()); + const int64 total_size = handle->total_size; + TensorShape shape = {total_size}; + + Tensor *ids; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids)); + + int id = 0; + int64 offset = 0; + auto ids_flat = ids->flat(); + for (const auto &record : records) { + // Since common features are added to each input, add the common_size to the current size. + // For training common_size == 0, for serving it can be a non-zero value. + int64 curr_size = static_cast(record.totalSize()) + common_size; + std::fill(ids_flat.data() + offset, ids_flat.data() + offset + curr_size, id); + offset += curr_size; + id++; + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + + +// OutType: Output Tensor Type. FieldType: The storage type used inside HashedDatarecord. +template +class GetOutputFromHashedDataRecord : public OpKernel { + protected: + using Getter = std::function&(const twml::HashedDataRecord &)>; + Getter getter; + + public: + explicit GetOutputFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const auto &common = handle->common; + const int64 total_size = handle->total_size; + TensorShape shape = {total_size}; + + Tensor *output; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); + + const auto &common_output = getter(common); + + auto output_data = output->flat().data(); + for (const auto &record : records) { + // This is does not copy anything during training as common_size == 0 + // It will copy the relevant common features coming from a batch prediction request. + output_data = std::copy(common_output.begin(), common_output.end(), output_data); + + // Copy the current record to output. + const auto& rec_output = getter(record); + output_data = std::copy(rec_output.begin(), rec_output.end(), output_data); + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetUKeysFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("ukeys: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns unhashed keys from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + ukeys: unhased keys / raw feature ids from the original request. 
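+
+For illustration (hypothetical values): a record decoded with raw feature ids
+[1001, 1002] yields ukeys = [1001, 1002]; GetKeysFromHashedDataRecord would
+instead return their hashed counterparts.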
+)doc"); + +class GetUKeysFromHashedDataRecord : public GetOutputFromHashedDataRecord { + public: + explicit GetUKeysFromHashedDataRecord(OpKernelConstruction* context) + : GetOutputFromHashedDataRecord(context){ + getter = [](const twml::HashedDataRecord &record) -> const std::vector & { + return record.keys(); + }; + } +}; + +REGISTER_OP("GetKeysFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("keys: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns keys from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + keys: keys after raw feature ids are hashed with values (int64) +)doc"); + +class GetKeysFromHashedDataRecord : public GetOutputFromHashedDataRecord { + public: + explicit GetKeysFromHashedDataRecord(OpKernelConstruction* context) + : GetOutputFromHashedDataRecord(context){ + getter = [](const twml::HashedDataRecord &record) -> const std::vector & { + return record.transformed_keys(); + }; + } +}; + +REGISTER_OP("GetValuesFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns values from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + values: feature values. +)doc"); + +class GetValuesFromHashedDataRecord : public GetOutputFromHashedDataRecord { + public: + explicit GetValuesFromHashedDataRecord(OpKernelConstruction* context) + : GetOutputFromHashedDataRecord(context){ + getter = [](const twml::HashedDataRecord &record) -> const std::vector & { + return record.values(); + }; + } +}; + +REGISTER_OP("GetCodesFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("codes: int64") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns codes from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + codes: deepbird feature code, usually from A,B,C,D ... in the config. +)doc"); + +class GetCodesFromHashedDataRecord : public GetOutputFromHashedDataRecord { + public: + explicit GetCodesFromHashedDataRecord(OpKernelConstruction* context) + : GetOutputFromHashedDataRecord(context){ + getter = [](const twml::HashedDataRecord &record) -> const std::vector & { + return record.codes(); + }; + } +}; + +REGISTER_OP("GetTypesFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("types: int8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns types from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + types: feature types corresponding to BINARY, DISCRETE, etc. 
+)doc"); + +class GetTypesFromHashedDataRecord : public GetOutputFromHashedDataRecord { + public: + explicit GetTypesFromHashedDataRecord(OpKernelConstruction* context) + : GetOutputFromHashedDataRecord(context){ + getter = [](const twml::HashedDataRecord &record) -> const std::vector & { + return record.types(); + }; + } +}; + +REGISTER_OP("GetBatchSizeFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("batch_size: int64") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that returns batch size from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + batch_size: Number of records held in the handle. +)doc"); + +class GetBatchSizeFromHashedDataRecord : public OpKernel { + public: + explicit GetBatchSizeFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + Tensor *output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = handle->records.size(); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetTotalSizeFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("total_size: int64") +.SetShapeFn(shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that returns total size from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + total_size: Total number of keys / values in the batch. +)doc"); + +class GetTotalSizeFromHashedDataRecord : public OpKernel { + public: + explicit GetTotalSizeFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + + Tensor *output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = handle->total_size; + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetLabelsFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("labels: float") +.Attr("default_label: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns labels from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + labels: A 2D tensor of size [batch_size, num_labels] containing the label values. +)doc"); + +class GetLabelsFromHashedDataRecord : public OpKernel { + private: + float default_label; + + public: + explicit GetLabelsFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("default_label", &default_label)); + } + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const int num_labels = static_cast(handle->num_labels); + TensorShape shape = {static_cast(handle->records.size()), num_labels}; + + Tensor *labels; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &labels)); + + // The default value of label is not present in data record is std::nanf + // For continuous labels, change that to a default_label or label. 
+ auto func = [this](float label) -> float { + return std::isnan(label) ? default_label : label; + }; + + auto labels_data = labels->flat().data(); + for (const auto &record : records) { + const auto& rec_labels = record.labels(); + labels_data = std::transform(rec_labels.begin(), rec_labels.end(), labels_data, func); + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_OP("GetWeightsFromHashedDataRecord") +.Input("hashed_data_record_handle: resource") +.Output("weights: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns weights from the hashed data record. +Input + hashed_data_record_handle: Resource handle to DataRecord + +Outputs + weights: A 2D tensor of size [batch_size, num_weights] containing the weight values. +)doc"); + +class GetWeightsFromHashedDataRecord : public OpKernel { + public: + explicit GetWeightsFromHashedDataRecord(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto handle = getHandle(context, 0); + const auto &records = handle->records; + const int num_weights = static_cast(handle->num_weights); + TensorShape shape = {static_cast(handle->records.size()), num_weights}; + + Tensor *weights; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &weights)); + + auto weights_data = weights->flat().data(); + for (const auto &record : records) { + const auto& rec_weights = record.weights(); + weights_data = std::copy(rec_weights.begin(), rec_weights.end(), weights_data); + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + + +#define REGISTER_DECODE_AND_HASH(InputType) \ + REGISTER_KERNEL_BUILDER( \ + Name("DecodeAndHashDataRecord") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("InputType"), \ + DecodeAndHashDataRecord); \ + +REGISTER_DECODE_AND_HASH(uint8) +REGISTER_DECODE_AND_HASH(string) + +#define REGISTER_GETTER(FIELD) \ + REGISTER_KERNEL_BUILDER( \ + Name("Get" #FIELD "FromHashedDataRecord") \ + .Device(DEVICE_CPU), \ + Get##FIELD##FromHashedDataRecord); \ + +REGISTER_GETTER(Ids) +REGISTER_GETTER(UKeys) +REGISTER_GETTER(Keys) +REGISTER_GETTER(Values) +REGISTER_GETTER(Codes) +REGISTER_GETTER(Types) +REGISTER_GETTER(BatchSize) +REGISTER_GETTER(TotalSize) +REGISTER_GETTER(Labels) +REGISTER_GETTER(Weights) diff --git a/twml/libtwml/src/ops/hashing_discretizer.cpp b/twml/libtwml/src/ops/hashing_discretizer.cpp new file mode 100644 index 000000000..634f6db33 --- /dev/null +++ b/twml/libtwml/src/ops/hashing_discretizer.cpp @@ -0,0 +1,260 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/work_sharder.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + +void ComputeHashingDiscretizer( + OpKernelContext*, + int64_t, + const twml::Map &, + int64_t, + int64_t, + int64_t); + +REGISTER_OP("HashingDiscretizer") +.Attr("T: {float, double}") +.Input("input_ids: int64") +.Input("input_vals: T") +.Input("bin_vals: T") +.Attr("feature_ids: tensor = { dtype: DT_INT64 }") +.Attr("n_bin: int") +.Attr("output_bits: int") +.Attr("cost_per_unit: int") +.Attr("options: int") +.Output("new_keys: int64") +.Output("new_vals: T") +.SetShapeFn( + [](::tensorflow::shape_inference::InferenceContext* c) { + 
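+      // The op is element-wise: new_keys mirrors input_ids and new_vals mirrors input_vals.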
c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + return Status::OK(); + } +) +.Doc(R"doc( + +This operation discretizes a tensor containing continuous features (if calibrated). + - note - choice of float or double should be consistent among inputs/output + +Input + input_ids(int64): A tensor containing input feature ids (direct from data record). + input_vals(float/double): A tensor containing input values at corresponding feature ids. + - i.e. input_ids[i] <-> input_vals[i] for each i + bin_vals(float/double): A tensor containing the bin boundaries for values of a given feature. + - float or double, matching input_vals + feature_ids(int64 attr): 1D TensorProto of feature IDs seen during calibration + -> hint: look up make_tensor_proto: + proto_init = np.array(values, dtype=np.int64) + tensor_attr = tf.make_tensor_proto(proto_init) + n_bin(int): The number of bin boundary values per feature + -> hence, n_bin + 1 buckets for each feature + output_bits(int): The maximum number of bits to use for the output IDs. + cost_per_unit(int): An estimate of the number of CPU cycles (or nanoseconds + if not CPU-bound) to complete a unit of work. Overestimating creates too + many shards and CPU time will be dominated by per-shard overhead, such as + Context creation. Underestimating may not fully make use of the specified + parallelism. + options(int): selects behavior of the op. + 0x00 in bits{1:0} for std::lower_bound bucket search. + 0x01 in bits{1:0} for linear bucket search + 0x02 in bits{1:0} for std::upper_bound bucket search + 0x00 in bits{4:2} for integer_multiplicative_hashing + 0x01 in bits{4:2} for integer64_multiplicative_hashing + higher bits/other values are reserved for future extensions + +Outputs + new_keys(int64): The discretized feature ids with same shape and size as keys. + new_vals(float or double): The discretized values with the same shape and size as vals. + +Operation + Note that the discretization operation maps observation vectors to higher dimensional + observation vectors. Here, we describe this mapping. + + Let a calibrated feature observation be given by (F,x), where F is the ID of the + feature, and x is some real value (i.e., continuous feature). This kind of + representation is useful for the representation of sparse vectors, where there + are many zeros. + + For example, for a dense feature vector [1.2, 2.4, 3.6], we might have + (0, 1.2) (1, 2.4) and (2, 3.6), with feature IDs indicating the 0th, 1st, and 2nd + elements of the vector. + + The disretizer performs the following operation: + (F,x) -> (map(x|F),1). + Hence, we have that map(x|F) is a new feature ID, and the value observed for that + feature is 1. We might read map(x|F) as 'the map of x for feature F'. + + For each feature F, we associate a (discrete, finite) set of new feature IDs, newIDs(F). + We will then have that map(x|F) is in the set newIDs(F) for any value of x. Each + set member of newIDs(F) is associated with a 'bin', as defined by the bin + boundaries given in the bin_vals input array. For any two different feature IDs F + and G, we would ideally have that INTERSECT(newIDs(F),newIDs(G)) is the empty set. + However, this is not guaranteed for this discretizer. + + In the case of this hashing discretizer, map(x|F) can actually be written as follows: + let bucket = bucket(x|F) be the the bucket index for x, according to the + calibration on F. (This is an integer value in [0,n_bin], inclusive) + F is an integer ID. Here, we have that map(x|F) = hash_fn(F,bucket). 
This has + the desirable property that the new ID depends only on the calibration data + supplied for feature F, and not on any other features in the dataset (e.g., + number of other features present in the calibration data, or order of features + in the dataset). Note that PercentileDiscretizer does NOT have this property. + This comes at the expense of the possibility of output ID collisions, which + we try to minimize through the design of hash_fn. + + Example - consider input vector with a single element, i.e. [x]. + Let's Discretize to one of 2 values, as follows: + Let F=0 for the ID of the single feature in the vector. + Let the bin boundary of feature F=0 be BNDRY(F) = BNDRY(0) since F=0 + bucket = bucket(x|F=0) = 0 if x<=BNDRY(0) else 1 + Let map(x|F) = hash_fn(F=0,bucket=0) if x<=BNDRY(0) else hash_fn(F=0,bucket=1) + If we had another element y in the vector, i.e. [x, y], then we might additionally + Let F=1 for element y. + Let the bin boundary be BNDRY(F) = BNDRY(1) since F=1 + bucket = bucket(x|F=1) = 0 if x<=BNDRY(1) else 1 + Let map(x|F) = hash_fn(F=1,bucket=0) if x<=BNDRY(1) else hash_fn(F=1,bucket=1) + Note how the construction of map(x|F=1) does not depend on whether map(x|F=0) + was constructed. +)doc"); + +template +class HashingDiscretizer : public OpKernel { + public: + explicit HashingDiscretizer(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("n_bin", &n_bin_)); + OP_REQUIRES(context, + n_bin_ > 0, + errors::InvalidArgument("Must have n_bin_ > 0.")); + + OP_REQUIRES_OK(context, + context->GetAttr("output_bits", &output_bits_)); + OP_REQUIRES(context, + output_bits_ > 0, + errors::InvalidArgument("Must have output_bits_ > 0.")); + + OP_REQUIRES_OK(context, + context->GetAttr("cost_per_unit", &cost_per_unit_)); + OP_REQUIRES(context, + cost_per_unit_ >= 0, + errors::InvalidArgument("Must have cost_per_unit >= 0.")); + + OP_REQUIRES_OK(context, + context->GetAttr("options", &options_)); + + // construct the ID_to_index hash map + Tensor feature_IDs; + + // extract the tensors + OP_REQUIRES_OK(context, + context->GetAttr("feature_ids", &feature_IDs)); + + // for access to the data + // int64_t data type is set in to_layer function of the calibrator objects in Python + auto feature_IDs_flat = feature_IDs.flat(); + + // verify proper dimension constraints + OP_REQUIRES(context, + feature_IDs.shape().dims() == 1, + errors::InvalidArgument("feature_ids must be 1D.")); + + // reserve space in the hash map and fill in the values + int64_t num_features = feature_IDs.shape().dim_size(0); +#ifdef USE_DENSE_HASH + ID_to_index_.set_empty_key(0); + ID_to_index_.resize(num_features); +#else + ID_to_index_.reserve(num_features); +#endif // USE_DENSE_HASH + for (int64_t i = 0 ; i < num_features ; i++) { + ID_to_index_[feature_IDs_flat(i)] = i; + } + } + + void Compute(OpKernelContext* context) override { + ComputeHashingDiscretizer( + context, + output_bits_, + ID_to_index_, + n_bin_, + cost_per_unit_, + options_); + } + + private: + twml::Map ID_to_index_; + int n_bin_; + int output_bits_; + int cost_per_unit_; + int options_; +}; + +#define REGISTER(Type) \ + REGISTER_KERNEL_BUILDER( \ + Name("HashingDiscretizer") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + HashingDiscretizer); \ + +REGISTER(float); +REGISTER(double); + +void ComputeHashingDiscretizer( + OpKernelContext* context, + int64_t output_bits, + const twml::Map &ID_to_index, + int64_t n_bin, + int64_t cost_per_unit, + int64_t options) { + const Tensor& keys = 
context->input(0); + const Tensor& vals = context->input(1); + const Tensor& bin_vals = context->input(2); + + const int64 output_size = keys.dim_size(0); + + TensorShape output_shape; + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(&output_size, 1, &output_shape)); + + Tensor* new_keys = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &new_keys)); + Tensor* new_vals = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &new_vals)); + + try { + twml::Tensor out_keys_ = TFTensor_to_twml_tensor(*new_keys); + twml::Tensor out_vals_ = TFTensor_to_twml_tensor(*new_vals); + + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + const twml::Tensor in_vals_ = TFTensor_to_twml_tensor(vals); + const twml::Tensor bin_vals_ = TFTensor_to_twml_tensor(bin_vals); + + // retrieve the thread pool from the op context + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Definition of the computation thread + auto task = [&](int64 start, int64 limit) { + twml::hashDiscretizerInfer(out_keys_, out_vals_, + in_keys_, in_vals_, + n_bin, + bin_vals_, + output_bits, + ID_to_index, + start, limit, + options); + }; + + // let Tensorflow split up the work as it sees fit + Shard(worker_threads.num_threads, + worker_threads.workers, + output_size, + static_cast(cost_per_unit), + task); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } +} + diff --git a/twml/libtwml/src/ops/hashmap.cpp b/twml/libtwml/src/ops/hashmap.cpp new file mode 100644 index 000000000..ce11ff81d --- /dev/null +++ b/twml/libtwml/src/ops/hashmap.cpp @@ -0,0 +1,84 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include + +#include + +using namespace tensorflow; + +REGISTER_OP("Hashmap") +.Input("keys: int64") +.Input("hash_keys: int64") +.Input("hash_values: int64") +.Output("values: int64") +.Output("mask: int8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // TODO: check if the sizes are different in the input + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); + }); + + +class Hashmap : public OpKernel { + private: + twml::HashMap hmap; + std::once_flag flag; + + public: + explicit Hashmap(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + // Quick hack + const Tensor& keys = context->input(0); + + std::call_once(this->flag, [this, context](){ + const Tensor& hash_keys = context->input(1); + const Tensor& hash_values = context->input(2); + const auto hash_keys_flat = hash_keys.flat(); + const auto hash_values_flat = hash_values.flat(); + const int64 N = hash_keys_flat.size(); + + for (int64 i = 0; i < N; i++) { + hmap.insert(hash_keys_flat(i), hash_values_flat(i)); + } + }); + + Tensor* values = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, keys.shape(), + &values)); + + Tensor* mask = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, keys.shape(), + &mask)); + + // copy the values without sharing a storage + values->flat() = keys.flat(); + + auto keys_flat = keys.flat(); + auto values_flat = values->flat(); + auto mask_flat = mask->flat(); + + // TODO: use twml tensor + const int64 N = keys_flat.size(); + for (int64 i = 0; i < N; i++) { + // values_flat(i), keys_flat(i) return references to tensorflow::int64. 
+ // Using them in hmap.get() was causing issues because of automatic casting. + int64_t val = values_flat(i); + int64_t key = keys_flat(i); + mask_flat(i) = hmap.get(val, key); + values_flat(i) = val; + } + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("Hashmap") + .Device(DEVICE_CPU), + Hashmap); diff --git a/twml/libtwml/src/ops/isotonic_calibration.cpp b/twml/libtwml/src/ops/isotonic_calibration.cpp new file mode 100644 index 000000000..10a8c22dc --- /dev/null +++ b/twml/libtwml/src/ops/isotonic_calibration.cpp @@ -0,0 +1,81 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" + +using namespace tensorflow; + +REGISTER_OP("IsotonicCalibration") +.Attr("T: {float, double}") +.Input("input: T") +.Input("xs: T") +.Input("ys: T") +.Output("output: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // output shape should be the same as input shape. + c->set_output(0, c->input(0)); + return Status::OK(); +}).Doc(R"doc( + +This operation calibrates probabilities by fitting to a piece-wise non-decreasing function. + +Input + input: A tensor containing uncalibrated probabilities. + xs: A tensor containing the boundaries of the bins. + ys: A tensor contianing calibrated values for the corresponding bins. + +Expected Sizes: + input: [batch_size, num_labels]. + xs, ys: [num_labels, num_bins]. + +Expected Types: + input: float or double. + xs, ys: same as input. + +Outputs + output: A tensor containing calibrated probabilities with same shape and size as input. + +)doc"); + +template +class IsotonicCalibration : public OpKernel { + public: + explicit IsotonicCalibration(OpKernelConstruction* context) + : OpKernel(context) {} + + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& xs = context->input(1); + const Tensor& ys = context->input(2); + + Tensor* output = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output(0, input.shape(), &output)); + + try { + const twml::Tensor twml_input = TFTensor_to_twml_tensor(input); + const twml::Tensor twml_xs = TFTensor_to_twml_tensor(xs); + const twml::Tensor twml_ys = TFTensor_to_twml_tensor(ys); + twml::Tensor twml_output = TFTensor_to_twml_tensor(*output); + + twml::linearInterpolation(twml_output, twml_input, twml_xs, twml_ys); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } + } +}; + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("IsotonicCalibration") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + IsotonicCalibration); \ + +REGISTER(float); +REGISTER(double); diff --git a/twml/libtwml/src/ops/num_intra_op_threads.cpp b/twml/libtwml/src/ops/num_intra_op_threads.cpp new file mode 100644 index 000000000..7e5ef0cbf --- /dev/null +++ b/twml/libtwml/src/ops/num_intra_op_threads.cpp @@ -0,0 +1,39 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/common_shape_fns.h" + +using namespace tensorflow; + +REGISTER_OP("NumIntraOpThreads") +.Input("x: float32") +.Output("num_intra_op_threads: int32") +.SetShapeFn(tensorflow::shape_inference::ScalarShape) +.Doc(R"doc( +A tensorflow OP that returns the number 
of threads in the intra_op_parallelism pool +This is not part of the Tensorflow API as of the date of writing this doc. Hence, +a tensorflow operation is the best resort. +Input + x: Dummy placeholder so that constant folding is not done by TF GraphOptimizer. + Please refer https://github.com/tensorflow/tensorflow/issues/22546 for more + details. +Output + num_intra_op_threads: A scalar tensor corresponding to the number of threads in + the intra_op_parallelism pool +)doc"); + +class NumIntraOpThreads : public OpKernel { + public: + explicit NumIntraOpThreads(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + int num_intra_op_threads = context->device()->tensorflow_cpu_worker_threads()->num_threads; + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output_tensor)); + auto output_flat = output_tensor->flat(); + output_flat(0) = num_intra_op_threads; + } +}; + +REGISTER_KERNEL_BUILDER(Name("NumIntraOpThreads").Device(DEVICE_CPU), NumIntraOpThreads); diff --git a/twml/libtwml/src/ops/par_add.cpp b/twml/libtwml/src/ops/par_add.cpp new file mode 100644 index 000000000..c03c1ad89 --- /dev/null +++ b/twml/libtwml/src/ops/par_add.cpp @@ -0,0 +1,75 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/logging.h" +#include + +#include + +using namespace tensorflow; + +REGISTER_OP("ParAdd") + .Input("input_a: float") + .Input("input_b: float") + .Output("a_plus_b: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + + +class ParAddOp : public OpKernel { + public: + explicit ParAddOp(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor0 = context->input(0); + auto input_flat0 = input_tensor0.flat(); + const Tensor& input_tensor1 = context->input(1); + auto input_flat1 = input_tensor1.flat(); + + OP_REQUIRES(context, input_tensor0.shape() == input_tensor1.shape(), + errors::InvalidArgument("Input tensors must be identical shape.")); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, + input_tensor0.shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + // PARALLEL ADD + const int N = input_flat0.size(); + + // retrieve the thread pool from the op context + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Definition of the computation thread + auto task = [=, &input_flat0, &input_flat1, &output_flat](int64 start, int64 limit) { + for (; start < limit; ++start) { + output_flat(start) = input_flat0(start) + input_flat1(start); + } + }; + + // this is a heuristic. 
+    // a high number is likely to be sharded into smaller pieces
+    int64 cost_per_unit = 1;
+
+    // let Tensorflow split up the work as it sees fit
+    Shard(worker_threads.num_threads,
+          worker_threads.workers,
+          N,
+          cost_per_unit,
+          task);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParAdd").Device(DEVICE_CPU), ParAddOp);
+
+
diff --git a/twml/libtwml/src/ops/partition_sparse_tensor.cpp b/twml/libtwml/src/ops/partition_sparse_tensor.cpp
new file mode 100644
index 000000000..4a210ba7f
--- /dev/null
+++ b/twml/libtwml/src/ops/partition_sparse_tensor.cpp
@@ -0,0 +1,125 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <twml.h>
+#include "tensorflow_utils.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("PartitionSparseTensorMod")
+.Attr("T: {float, double}")
+.Input("indices: int64")
+.Input("values: T")
+.Output("result: output_types")
+.Attr("num_partitions: int")
+.Attr("output_types: list({int64, float, double})")
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+  return Status::OK();
+}).Doc(R"doc(
+
+A tensorflow OP that partitions an input batch represented as a sparse tensor
+(indices are [ids, keys]) into separate sparse tensors to more optimally place
+sparse computations in distributed training.
+
+Inputs
+  indices: Indices from sparse tensor ([ids, keys] from the batch).
+  values: Batch values from the original features dict.
+
+Attr
+  num_partitions: Number of partitions to generate.
+  output_types: A list of types for the output tensors like
+    [tf.int64, tf.float32, tf.int64, tf.float32, ...]
+    The length must be 2 * num_partitions (see Outputs below).
+
+Outputs
+  List of dense tensors containing for each partition:
+    - partitioned indices tensor ([ids, keys] from partitioned batch)
+    - partitioned values tensor
+  The list length is 2 * num_partitions. Example:
+  [ [ids_1, keys_1], values_1, [ids_2, keys_2], values_2, ...
] +)doc"); + +template +class PartitionSparseTensorMod : public OpKernel { + private: + int64 num_partitions; + + public: + explicit PartitionSparseTensorMod(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_partitions", &num_partitions)); + OP_REQUIRES(context, num_partitions > 0, + errors::InvalidArgument("Number of partitions must be positive")); + } + + void Compute(OpKernelContext* context) override { + // grab input tensors + const Tensor& indices_tensor = context->input(0); // (ids, keys) + const Tensor& values_tensor = context->input(1); + + // check sizes + int64 num_keys = indices_tensor.shape().dim_size(0); + OP_REQUIRES(context, indices_tensor.dims() == 2, + errors::InvalidArgument("Indices tensor must be 2D [ids, keys]")); + OP_REQUIRES(context, indices_tensor.shape().dim_size(1) == 2, + errors::InvalidArgument("Indices tensor must have 2 cols [ids, keys]")); + OP_REQUIRES(context, values_tensor.shape().dim_size(0) == num_keys, + errors::InvalidArgument("Number of values must match number of keys")); + + // grab input vectors + auto indices = indices_tensor.flat(); + auto values = values_tensor.flat(); + + // count the number of features that fall in each partition + std::vector partition_counts(num_partitions); + + for (int i = 0; i < num_keys; i++) { + int64 key = indices(2 * i + 1); + int64 partition_id = key % num_partitions; + partition_counts[partition_id]++; + } + + // allocate outputs for each partition and keep references + std::vector output_indices_partitions; + std::vector output_values_partitions; + output_indices_partitions.reserve(num_partitions); + output_values_partitions.reserve(num_partitions); + + for (int i = 0; i < num_partitions; i++) { + Tensor *output_indices = nullptr, *output_values = nullptr; + TensorShape shape_indices = TensorShape({partition_counts[i], 2}); + TensorShape shape_values = TensorShape({partition_counts[i]}); + + OP_REQUIRES_OK(context, context->allocate_output(2 * i, shape_indices, &output_indices)); + OP_REQUIRES_OK(context, context->allocate_output(2 * i + 1, shape_values, &output_values)); + + output_indices_partitions.push_back(output_indices->flat().data()); + output_values_partitions.push_back(output_values->flat().data()); + } + + // assign a partition id to each feature + // populate tensors for each partition + std::vector partition_indices(num_partitions); + + for (int i = 0; i < num_keys; i++) { + int64 key = indices(2 * i + 1); + int64 pid = key % num_partitions; // partition id + int64 idx = partition_indices[pid]++; + + output_indices_partitions[pid][2 * idx] = indices(2 * i); + output_indices_partitions[pid][2 * idx + 1] = key / num_partitions; + output_values_partitions[pid][idx] = values(i); + } + } +}; + +#define REGISTER(Type) \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("PartitionSparseTensorMod") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + PartitionSparseTensorMod); \ + +REGISTER(float); +REGISTER(double); diff --git a/twml/libtwml/src/ops/percentile_discretizer_v2.cpp b/twml/libtwml/src/ops/percentile_discretizer_v2.cpp new file mode 100644 index 000000000..2a0dac7d8 --- /dev/null +++ b/twml/libtwml/src/ops/percentile_discretizer_v2.cpp @@ -0,0 +1,241 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/work_sharder.h" + +#include +#include "tensorflow_utils.h" + + +using namespace tensorflow; + +void 
CombinedComputeDiscretizers( + OpKernelContext*, + int64_t, + const twml::Map&, + int64_t); + +REGISTER_OP("PercentileDiscretizerV2") +.Attr("T: {float, double}") +.Input("input_ids: int64") +.Input("input_vals: T") +.Input("bin_ids: int64") +.Input("bin_vals: T") +.Input("feature_offsets: int64") +.Input("start_compute: int64") +.Input("end_compute: int64") +.Attr("output_bits: int") +.Attr("feature_ids: tensor = { dtype: DT_INT64 }") +.Attr("feature_indices: tensor = { dtype: DT_INT64 }") +.Attr("cost_per_unit: int") +.Output("new_keys: int64") +.Output("new_vals: T") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + // TODO: check sizes + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); +}).Doc(R"doc( + +This operation discretizes a tensor containing continuous features (if calibrated). + - note - choice of float or double should be consistent among inputs/output + +Input + input_ids(int64): A tensor containing input feature ids (direct from data record). + input_vals: A tensor containing input values at corresponding feature ids. + - i.e. input_ids[i] <-> input_vals[i] for each i + - float or double + bin_ids(int64): A tensor containing the discretized feature id for each bin. + bin_vals: A tensor containing the bin boundaries for values of a given feature. + - float or double + feature_offsets(int64): Specifies the starting location of bins for a given feature id. + start_compute(int64 scalar tensor): which index to start the computation at + end_compute(int64 scalar tensor): which index to end the computation right before + -> for example, (start_compute,end_compute)=(0,10) would compute on 0 thru 9 + output_bits(int): The maximum number of bits to use for the output IDs. + -> 2**out_bits must be greater than bin_ids.size + feature_ids(int64): 1D TensorProto of feature IDs seen during calibration + feature_indices(int64): 1D TensorProto of feature indices corresponding with feature_IDs + -> hint: look up make_tensor_proto: + proto_init = np.array(values, dtype=np.int64) + tensor_attr = tf.make_tensor_proto(my_proto_init) + cost_per_unit(int): An estimate of the number of CPU cycles (or nanoseconds + if not CPU-bound) to complete a unit of work. Overestimating creates too + many shards and CPU time will be dominated by per-shard overhead, such as + Context creation. Underestimating may not fully make use of the specified + parallelism. + +Outputs + new_keys(int64): The discretized feature ids with same shape and size as keys. + new_vals(float or double): The discretized values with the same shape and size as vals. + +Operation + Note that the discretization operation maps observation vectors to higher dimensional + observation vectors. Here, we describe this mapping. + + Let a calibrated feature observation be given by (F,x), where F is the ID of the + feature, and x is some real value (i.e., continuous feature). This kind of + representation is useful for the representation of sparse vectors, where there + are many zeros. + + For example, for a dense feature vector [1.2, 2.4, 3.6], we might have + (0, 1.2) (1, 2.4) and (2, 3.6), with feature IDs indicating the 0th, 1st, and 2nd + elements of the vector + + The disretizer performs the following operation: + (F,x) -> (map(x|F),1). + Hence, we have that map(x|F) is a new feature ID, and the value observed for that + feature is 1. We might read map(x|F) as 'the map of x for feature F'. + + For each feature F, we associate a (discrete, finite) set of new feature IDs, newIDs(F). 
+ We will then have that F~(x) is in the set newIDs(F) for any value of x. Each set member + of newIDs(F) is associated with a 'bin', as defined by the bin boundaries given in + the bin_vals input array. For any two different feature IDs F and G, we have that + INTERSECT(newIDs(F),newIDs(G)) is the empty set + + Example - consider input vector with a single element, i.e. [x]. + Let's Discretize to one of 2 values, as follows: + Let F=0 for the ID of the single feature in the vector. + Let the bin boundary of feature F=0 be BNDRY(F) = BNDRY(0) since F=0 + Let newIDs(F) = newIDs(0) = {0,1} + Let map(x|F) = map(x|0) = 0 if x<=BNDRY else 1 + If we had another element y in the vector, i.e. [x, y], then we might additionally + Let F=1 for element y. + Let the bin boundary be BNDRY(F) = BNDRY(1) since F=1 + Let newIDs(F) = newIDs(1) = {2,3} (so as to have empty intersect with newIDs(0)) + Let map(x|F) = map(x|1) = 2 if x<=BNDRY else 3 + Consider vector observation [-0.1, 0.2]. We then represent this as [(0, -0.1), (1, 0.2)] + Let BNDRY(0) = BNDRY(1) = 0. When we discretize the vector, we get: + (0, -0.1) -> (map(-0.1|0), 1) = (0, 1) + (1, 0.2) -> (map( 0.2|1), 1) = (3, 1) + Our output vector is then represented sparsely as [(0, 1), (3, 1)], and the dense + representation of this could be [1, 0, 0, 1] + +)doc"); + +template +class PercentileDiscretizerV2 : public OpKernel { + public: + explicit PercentileDiscretizerV2(OpKernelConstruction* context) : OpKernel(context) { + // get the number of output bits + // for use with features that have not been calibrated + OP_REQUIRES_OK(context, + context->GetAttr("output_bits", &output_bits_)); + OP_REQUIRES_OK(context, + context->GetAttr("cost_per_unit", &cost_per_unit_)); + OP_REQUIRES(context, cost_per_unit_ >= 0, + errors::InvalidArgument("Must have cost_per_unit >= 0.")); + + // construct the ID_to_index hash map + Tensor feature_IDs; + Tensor feature_indices; + + // extract the tensors + OP_REQUIRES_OK(context, + context->GetAttr("feature_ids", &feature_IDs)); + OP_REQUIRES_OK(context, + context->GetAttr("feature_indices", &feature_indices)); + + // for access to the data + // int64_t data type is set in to_layer function of the calibrator objects in Python + auto feature_IDs_flat = feature_IDs.flat(); + auto feature_indices_flat = feature_indices.flat(); + + // verify proper dimension constraints + OP_REQUIRES(context, feature_IDs.shape() == feature_indices.shape(), + errors::InvalidArgument("feature_ids and feature_indices must be identical shape.")); + OP_REQUIRES(context, feature_IDs.shape().dims() == 1, + errors::InvalidArgument("feature_ids and feature_indices must be 1D.")); + + // reserve space in the hash map and fill in the values + int num_features = feature_IDs.shape().dim_size(0); + +#ifdef USE_DENSE_HASH + ID_to_index_.set_empty_key(0); + ID_to_index_.resize(num_features); +#else + ID_to_index_.reserve(num_features); +#endif // USE_DENSE_HASH + for (int i = 0 ; i < num_features ; i++) { + ID_to_index_[feature_IDs_flat(i)] = feature_indices_flat(i); + } + } + + void Compute(OpKernelContext* context) override { + CombinedComputeDiscretizers( + context, + output_bits_, + ID_to_index_, + cost_per_unit_); + } + + private: + twml::Map ID_to_index_; + int output_bits_; + int cost_per_unit_; +}; + +#define REGISTER(Type) \ + REGISTER_KERNEL_BUILDER( \ + Name("PercentileDiscretizerV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + PercentileDiscretizerV2); \ + +REGISTER(float); +REGISTER(double); + +void CombinedComputeDiscretizers( + 
OpKernelContext* context, + int64_t output_bits, + const twml::Map &ID_to_index, + int64_t cost_per_unit) { + const Tensor& keys = context->input(0); + const Tensor& vals = context->input(1); + const Tensor& bin_ids = context->input(2); + const Tensor& bin_vals = context->input(3); + const Tensor& feature_offsets = context->input(4); + + uint64 full_size = keys.dim_size(0); + const int total_size = static_cast(full_size); + TensorShape output_shape = {total_size}; + + Tensor* new_keys = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &new_keys)); + Tensor* new_vals = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &new_vals)); + + try { + twml::Tensor out_keys_ = TFTensor_to_twml_tensor(*new_keys); + twml::Tensor out_vals_ = TFTensor_to_twml_tensor(*new_vals); + + const twml::Tensor in_keys_ = TFTensor_to_twml_tensor(keys); + const twml::Tensor in_vals_ = TFTensor_to_twml_tensor(vals); + const twml::Tensor bin_ids_ = TFTensor_to_twml_tensor(bin_ids); + const twml::Tensor bin_vals_ = TFTensor_to_twml_tensor(bin_vals); + const twml::Tensor feature_offsets_ = TFTensor_to_twml_tensor(feature_offsets); + + // retrieve the thread pool from the op context + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Definition of the computation thread + auto task = [&](int64 start, int64 limit) { + twml::discretizerInfer(out_keys_, out_vals_, + in_keys_, in_vals_, + bin_ids_, bin_vals_, + feature_offsets_, output_bits, + ID_to_index, + start, limit, + start); + }; + + // let Tensorflow split up the work as it sees fit + Shard(worker_threads.num_threads, + worker_threads.workers, + full_size, + static_cast(cost_per_unit), + task); + } catch (const std::exception &e) { + context->CtxFailureWithWarning(errors::InvalidArgument(e.what())); + } +} diff --git a/twml/libtwml/src/ops/resource_utils.h b/twml/libtwml/src/ops/resource_utils.h new file mode 100644 index 000000000..a41fe6845 --- /dev/null +++ b/twml/libtwml/src/ops/resource_utils.h @@ -0,0 +1,126 @@ +#pragma once + +#include + +#include +#include +#include + +// Add these to make gcc ignore the warnings from tensorflow. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" + +#pragma GCC diagnostic pop + +#include +#include + +template +void unrefHandle(T *handle) { + handle->Unref(); +} + +template +using unique_handle = std::unique_ptr >; + +// as std::type_index is not abi compatible, we bypass the hash_code checks. 
+// https://github.com/tensorflow/tensorflow/commit/15275d3a14c77e2244ae1155f93243256f08e3ed +#ifdef __APPLE__ +template +Status CreateTwmlResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { + return ctx->resource_manager()->Create(p.container(), p.name(), value); +} + +template +Status LookupTwmlResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value) { + return ctx->resource_manager()->Lookup(p.container(), p.name(), value); +} +#endif // __APPLE__ + +template +unique_handle getHandle(tensorflow::OpKernelContext* context, int input_idx) { + using namespace tensorflow; + T *ptr = nullptr; +#ifdef __APPLE__ + auto s = LookupTwmlResource(context, HandleFromInput(context, input_idx), &ptr); +#else + auto s = LookupResource(context, HandleFromInput(context, input_idx), &ptr); +#endif // __APPLE__ + + if (!s.ok()) { + throw std::runtime_error("Failed to get resource handle"); + } + return unique_handle(ptr, unrefHandle); +} + +template +const uint8_t *getInputBytes(const Tensor &input, int id) { + return reinterpret_cast(input.flat().data()); +} + +template<> +inline const uint8_t *getInputBytes(const Tensor &input, int id) { + return reinterpret_cast(input.flat()(id).c_str()); +} + +template +const int getBatchSize(const Tensor &input) { + return 1; +} + +template<> +inline const int getBatchSize(const Tensor &input) { + return static_cast(input.NumElements()); +} + +class DataRecordResource : public ResourceBase { + public: + Tensor input; + int64 num_labels; + int64 num_weights; + twml::DataRecord common; + std::vector records; + twml::Map *keep_map; + string DebugString() const override { return "DataRecords resource"; } +}; + +// A thin layer around batch of HashedDataRecords +class HashedDataRecordResource : public ResourceBase { + public: + Tensor input; + int64 total_size; + int64 num_labels; + int64 num_weights; + twml::HashedDataRecord common; + std::vector records; + string DebugString() const override { return "HashedDataRecord Resource"; } +}; + +#define TF_CHECK_STATUS(fn) do { \ + Status s = fn; \ + if (!s.ok()) return s; \ + } while (0) + +template +Status makeResourceHandle(OpKernelContext* context, int out_idx, ResourceType **resource_) { + static std::atomic id; + Tensor* handle_tensor; + TF_CHECK_STATUS(context->allocate_output(out_idx, TensorShape({}), &handle_tensor)); + + ResourceType *resource = new ResourceType(); + const auto resource_name = typeid(ResourceType).name() + std::to_string(id++); + ResourceHandle handle = MakePerStepResourceHandle(context, resource_name); +#ifdef __APPLE__ + TF_CHECK_STATUS(CreateTwmlResource(context, handle, resource)); +#else + TF_CHECK_STATUS(CreateResource(context, handle, resource)); +#endif // __APPLE__ + handle_tensor->scalar()() = handle; + + *resource_ = resource; + return Status::OK(); +} diff --git a/twml/libtwml/src/ops/scripts/get_inc.py b/twml/libtwml/src/ops/scripts/get_inc.py new file mode 100644 index 000000000..c50edfa90 --- /dev/null +++ b/twml/libtwml/src/ops/scripts/get_inc.py @@ -0,0 +1,5 @@ +"""Gets the path of headers for the current Tensorflow library""" + +import tensorflow.compat.v1 as tf + +print(tf.sysconfig.get_include(), end='') diff --git a/twml/libtwml/src/ops/scripts/get_inc.sh b/twml/libtwml/src/ops/scripts/get_inc.sh new file mode 100644 index 000000000..5cb064338 --- /dev/null +++ b/twml/libtwml/src/ops/scripts/get_inc.sh @@ -0,0 +1,2 @@ +#!/bin/sh +PEX_INTERPRETER=1 "$PYTHON_ENV" "$LIBTWML_HOME"/src/ops/scripts/get_inc.py diff --git a/twml/libtwml/src/ops/scripts/get_lib.py 
b/twml/libtwml/src/ops/scripts/get_lib.py
new file mode 100644
index 000000000..7150c48b7
--- /dev/null
+++ b/twml/libtwml/src/ops/scripts/get_lib.py
@@ -0,0 +1,5 @@
+"""Gets the library path of the current Tensorflow install"""
+
+import tensorflow.compat.v1 as tf
+
+print(tf.sysconfig.get_lib(), end='')
diff --git a/twml/libtwml/src/ops/scripts/get_lib.sh b/twml/libtwml/src/ops/scripts/get_lib.sh
new file mode 100644
index 000000000..1b9d802b6
--- /dev/null
+++ b/twml/libtwml/src/ops/scripts/get_lib.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+PEX_INTERPRETER=1 "$PYTHON_ENV" "$LIBTWML_HOME"/src/ops/scripts/get_lib.py
diff --git a/twml/libtwml/src/ops/scripts/symlink.sh b/twml/libtwml/src/ops/scripts/symlink.sh
new file mode 100644
index 000000000..2ddb76371
--- /dev/null
+++ b/twml/libtwml/src/ops/scripts/symlink.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+# Needed to create a "nice" symlink to _pywrap_tensorflow_internal.so so
+# that cmake can link with the library properly.
+
+# This library is only needed for streaming datasets and is linked with
+# libtwml_tf_data.so which will not be used at runtime.
+
+TF_PYTHON_LIB_DIR=$(PEX_INTERPRETER=1 "$PYTHON_ENV" "$TWML_HOME"/backends/tensorflow/src/scripts/get_lib.py)
+TF_INTERNAL_LIB=$TWML_HOME/backends/tensorflow/twml/lib/libtensorflow_internal.so
+rm -f "$TF_INTERNAL_LIB"
+ln -s "$TF_PYTHON_LIB_DIR"/python/_pywrap_tensorflow_internal.so "$TF_INTERNAL_LIB"
diff --git a/twml/libtwml/src/ops/sleep_op.cpp b/twml/libtwml/src/ops/sleep_op.cpp
new file mode 100644
index 000000000..dd9a1834c
--- /dev/null
+++ b/twml/libtwml/src/ops/sleep_op.cpp
@@ -0,0 +1,51 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <chrono>
+#include <thread>
+
+using namespace tensorflow;
+
+REGISTER_OP("Sleep")
+.Input("num_milliseconds: int32")
+.Output("sleep_time_in_ms: int32")
+.SetShapeFn(tensorflow::shape_inference::ScalarShape)
+.Doc(R"doc(
+A tensorflow OP that sleeps for a specified number of milliseconds.
+This is a proxy to determine the size of the inter_op_parallelism pool.
+This is not part of the Tensorflow API as of the date of writing this
+doc. Hence, a tensorflow operation is the best resort.
+Input + num_milliseconds: A scalar tensor corresponding to the number + of milliseconds the operation should sleep for +Output + sleep_time_in_ms: A scalar tensor corresponding to the + actual number of milliseconds for which the operation slept +)doc"); + +class SleepOp : public OpKernel { + public: + explicit SleepOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Sleep for specified milliseconds + auto start = std::chrono::high_resolution_clock::now(); + std::this_thread::sleep_for(std::chrono::milliseconds(input(0))); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end-start; + + // Set the output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output_tensor)); + auto output_flat = output_tensor->flat(); + output_flat(0) = elapsed.count(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("Sleep").Device(DEVICE_CPU), SleepOp); diff --git a/twml/libtwml/src/ops/sparse_normalization.cpp b/twml/libtwml/src/ops/sparse_normalization.cpp new file mode 100644 index 000000000..9b079429c --- /dev/null +++ b/twml/libtwml/src/ops/sparse_normalization.cpp @@ -0,0 +1,378 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; + +REGISTER_OP("SparseMaxNorm") +.Attr("epsilon: float") +.Input("max_values: Ref(float)") +.Input("indices: int64") +.Input("values: float") +.Input("is_training: bool") +.Output("updated_max_values: Ref(float)") +.Output("normalized_values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that normalizes a batch of sparse inputs based on the current maximum value. + +Input + max_values: float tensor variable representing the max values seen so far. + indices: int64 tensor representing indices representing a feature. + values: float tensor representing values for the current batch. + is_training: bool tensor specifying if the op should be run in training mode or not. + +Outputs + updated_max_values: max_values updated with the current batch. + normalized_values: Input values normalized by the max value seen so far. + +The pseudo code for normalization can be seen below: + + # During training / inference + for i, idx in enumerate(indices): + updated_max_values[idx] = max(max_values[idx], abs(values[i])) + normalized_values[i] = values[i] / updated_max_values[idx] + +)doc"); + +class SparseMaxNorm : public OpKernel { + private: + float epsilon_; + + public: + explicit SparseMaxNorm(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_)); + } + + void Compute(OpKernelContext *context) override { + // We always return the input ref. 
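+    // forward_ref_input_to_ref_output(0, 0) aliases output 0 to the
+    // max_values variable ref (input 0), so the in-place updates made in
+    // the loop below are visible through updated_max_values without a copy.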
+ context->forward_ref_input_to_ref_output(0, 0); + Tensor max_values_tensor = context->mutable_input(0, false); + + OP_REQUIRES(context, max_values_tensor.IsInitialized(), + errors::FailedPrecondition("Attempting to use uninitialized " + "parameters: ", + requested_input(0))); + + const Tensor &indices_tensor = context->input(1); + const Tensor &values_tensor = context->input(2); + const Tensor &is_training_tensor = context->input(3); + + const auto indices = indices_tensor.flat(); + const auto values = values_tensor.flat(); + const bool is_training = is_training_tensor.scalar()(); + + auto max_values = max_values_tensor.flat(); + Tensor *normalized_values_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, values_tensor.shape(), + &normalized_values_tensor)); + + auto normalized_values = normalized_values_tensor->flat(); + + const int64 N = indices.size(); + + for (int64 i = 0; i < N; i++) { + int64 idx = indices(i); + float value = values(i); + float max_value = std::max(max_values(idx), std::abs(value)); + + // Guaranteed to be between [-1, 1]. + normalized_values(i) = value / std::max(max_value, epsilon_); + + if (is_training) { + max_values(idx) = max_value; + } + } + } +}; + +REGISTER_OP("SparseBatchNorm") +.Attr("input_size: int") +.Attr("epsilon: float") +.Input("means: Ref(float)") +.Input("variances: Ref(float)") +.Input("indices: int64") +.Input("values: float") +.Input("is_training: bool") +.Output("updated_means: Ref(float)") +.Output("updated_vars: Ref(float)") +.Output("normalized_values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that performs batch normalization. + +Attr + input_size: Size of the inputs. + epsilon: The minimum value of the variance. + +Input + mean: float tensor variable representing the running mean seen so far. + variances: float tensor variable representing the running variance seen so far. + indices: int64 tensor representing indices representing a feature. + values: float tensor representing values for the current batch. + is_training: bool tensor specifying if the op should be run in training mode or not. + +Outputs + updated_means: mean updated with the current batch. + updated_vars: variances updated with the current batch. + normalized_values: Input values normalized by the max value seen so far. + +The pseudo code for normalization can be seen below: + + if is_training: + means, variances = update_metrics(means, variances, values) + + normalized_values = (values - means) / sqrt(variances + epsilon) + return normalized_values * gamma + beta + +)doc"); + +class SparseBatchNorm : public OpKernel { + private: + std::vector counts_; + std::vector m2s_; + float epsilon_; + + public: + explicit SparseBatchNorm(OpKernelConstruction *context) : OpKernel(context) { + int64 input_size; + OP_REQUIRES_OK(context, context->GetAttr("input_size", &input_size)); + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_)); + counts_.resize(input_size); + m2s_.resize(input_size); + } + + void Compute(OpKernelContext *context) override { + // We always return the input ref. 
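+    // Outputs 0 and 1 alias the means and variances variable refs, so the
+    // running statistics updated below are exposed without copying.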
+ context->forward_ref_input_to_ref_output(0, 0); + context->forward_ref_input_to_ref_output(1, 1); + + Tensor means_tensor = context->mutable_input(0, true); + Tensor variances_tensor = context->mutable_input(1, true); + + OP_REQUIRES(context, means_tensor.IsInitialized(), + errors::FailedPrecondition("Attempting to use uninitialized " + "parameters: ", + requested_input(0))); + + OP_REQUIRES(context, variances_tensor.IsInitialized(), + errors::FailedPrecondition("Attempting to use uninitialized " + "parameters: ", + requested_input(1))); + + const Tensor &indices_tensor = context->input(2); + const Tensor &values_tensor = context->input(3); + const Tensor &is_training_tensor = context->input(4); + + const auto indices = indices_tensor.flat(); + const auto values = values_tensor.flat(); + const bool is_training = is_training_tensor.scalar()(); + + auto means = means_tensor.flat(); + auto variances = variances_tensor.flat(); + Tensor *normalized_values_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(2, values_tensor.shape(), + &normalized_values_tensor)); + + auto normalized_values = normalized_values_tensor->flat(); + const int64 N = indices.size(); + + if (is_training) { + // Accumulate, mean, count, sum of squared differences. + // Reference wiki: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + // Reference paper: + // https://www.jstor.org/stable/1266577?seq=1#page_scan_tab_contents + for (int64 i = 0; i < N; i++) { + int64 idx = indices(i); + int64 count = counts_[idx] + 1; + + float value = values(i); + float old_mean = means(idx); + float old_delta = value - old_mean; + float new_mean = old_mean + old_delta / count; + float new_delta = value - new_mean; + + counts_[idx] = count; + m2s_[idx] += new_delta * old_delta; + means(idx) = new_mean; + variances(idx) = m2s_[idx] / count; + } + } + + // Normalize the values + for (int64 i = 0; i < N; i++) { + int64 idx = indices(i); + float stdev = std::sqrt(variances(idx) + epsilon_); + normalized_values(i) = (values(i) - means(idx)) / stdev; + } + } +}; + +REGISTER_OP("SparseMaxNormInference") +.Attr("epsilon: float") +.Input("max_values: float") +.Input("indices: int64") +.Input("values: float") +.Output("normalized_values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that normalizes a batch of sparse inputs based on the current maximum value. +This is the inference OP. + +Input + max_values: float tensor representing the max values seen so far. + indices: int64 tensor representing indices representing a feature. + values: float tensor representing values for the current batch. + +Outputs + normalized_values: Input values normalized by the max value seen so far. 
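+  Note: unlike the training OP, max_values is a plain (non-Ref) input here;
+  the max in the pseudo code below is computed per value and never written back.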
+ +The pseudo code for normalization can be seen below: + + # During inference + for i, idx in enumerate(indices): + updated_max_values[idx] = max(max_values[idx], abs(values[i])) + normalized_values[i] = values[i] / updated_max_values[idx] + +)doc"); + +class SparseMaxNormInference : public OpKernel { + private: + float epsilon_; + + public: + explicit SparseMaxNormInference(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_)); + } + + void Compute(OpKernelContext *context) override { + const Tensor &max_values_tensor = context->input(0); + const Tensor &indices_tensor = context->input(1); + const Tensor &values_tensor = context->input(2); + + const auto max_values = max_values_tensor.flat(); + const auto indices = indices_tensor.flat(); + const auto values = values_tensor.flat(); + + Tensor *normalized_values_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, values_tensor.shape(), + &normalized_values_tensor)); + + auto normalized_values = normalized_values_tensor->flat(); + + const int64 N = indices.size(); + + for (int64 i = 0; i < N; i++) { + int64 idx = indices(i); + float value = values(i); + float max_value = std::max(max_values(idx), std::abs(value)); + + // Guaranteed to be between [-1, 1]. + normalized_values(i) = value / std::max(max_value, epsilon_); + } + } +}; + +REGISTER_OP("SparseMaxNormTraining") +.Attr("epsilon: float") +.Input("max_values: float") +.Input("indices: int64") +.Input("values: float") +.Output("updated_max_values: float") +.Output("normalized_values: float") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that normalizes a batch of sparse inputs based on the current maximum value. +This is the training OP. + +Input + max_values: float tensor variable representing the max values seen so far. + indices: int64 tensor representing indices representing a feature. + values: float tensor representing values for the current batch. + +Outputs + updated_max_values: max_values updated with the current batch. + normalized_values: Input values normalized by the max value seen so far. 
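+  For example, with max_values = [2.0] and one batch entry (idx=0, value=-3.0),
+  the OP returns updated_max_values = [3.0] and normalized value -3.0/3.0 = -1.0.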
+ +The pseudo code for normalization can be seen below: + + # During training + for i, idx in enumerate(indices): + updated_max_values[idx] = max(max_values[idx], abs(values[i])) + normalized_values[i] = values[i] / updated_max_values[idx] + +)doc"); + +class SparseMaxNormTraining : public OpKernel { + private: + float epsilon_; + + public: + explicit SparseMaxNormTraining(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_)); + } + + void Compute(OpKernelContext *context) override { + const Tensor &max_values_tensor = context->input(0); + const Tensor &indices_tensor = context->input(1); + const Tensor &values_tensor = context->input(2); + + const auto max_values = max_values_tensor.flat(); + const auto indices = indices_tensor.flat(); + const auto values = values_tensor.flat(); + + Tensor *updated_max_values_tensor = nullptr; + Tensor *normalized_values_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, max_values_tensor.shape(), + &updated_max_values_tensor)); + OP_REQUIRES_OK(context, context->allocate_output(1, values_tensor.shape(), + &normalized_values_tensor)); + + auto updated_max_values = updated_max_values_tensor->flat(); + auto normalized_values = normalized_values_tensor->flat(); + + const int64 N = indices.size(); + + // This copy is needed because the values of updated_max_values are originally garbage. + // Also note that N is not the same as max_values.size() + std::copy(max_values.data(), max_values.data() + max_values.size(), updated_max_values.data()); + + for (int64 i = 0; i < N; i++) { + int64 idx = indices(i); + float value = values(i); + float updated_max_value = std::max(updated_max_values(idx), std::abs(value)); + // Guaranteed to be between [-1, 1]. + normalized_values(i) = value / std::max(updated_max_value, epsilon_); + // Saving the updated_max_values + updated_max_values(idx) = updated_max_value; + } + } +}; + + + + +REGISTER_KERNEL_BUILDER( + Name("SparseMaxNorm") + .Device(DEVICE_CPU), + SparseMaxNorm); + +REGISTER_KERNEL_BUILDER( + Name("SparseBatchNorm") + .Device(DEVICE_CPU), + SparseBatchNorm); + +REGISTER_KERNEL_BUILDER( + Name("SparseMaxNormInference") + .Device(DEVICE_CPU), + SparseMaxNormInference); + +REGISTER_KERNEL_BUILDER( + Name("SparseMaxNormTraining") + .Device(DEVICE_CPU), + SparseMaxNormTraining); diff --git a/twml/libtwml/src/ops/tensor_record.cpp b/twml/libtwml/src/ops/tensor_record.cpp new file mode 100644 index 000000000..ad044e378 --- /dev/null +++ b/twml/libtwml/src/ops/tensor_record.cpp @@ -0,0 +1,692 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include "tensorflow_utils.h" +#include "resource_utils.h" + +#include +using std::string; + +REGISTER_OP("GetStringTensorsFromDataRecord") +.Attr("feature_id: int") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("strings: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns string tensors from the data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + ids: A 1D int64 tensor representing the input index in a given batch. + strings: A 1D string tensor representing the decoded strings from the batch. 
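+
+  For example, if input 0 of the batch decodes to strings ["a", "b"] and
+  input 1 decodes to ["c"], the OP returns ids = [0, 0, 1] and
+  strings = ["a", "b", "c"].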
+)doc"); + +REGISTER_OP("GetStringTensorsFromHashedDataRecord") +.Attr("feature_id: int") +.Input("hashed_data_record_handle: resource") +.Output("ids: int64") +.Output("strings: string") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns string tensors from the hashed data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + ids: A 1D int64 tensor representing the input index in a given batch. + strings: A 1D string tensor representing the decoded strings from the batch. +)doc"); + +template +class GetStringTensorsOp : public OpKernel { + private: + int64 feature_id; + + public: + explicit GetStringTensorsOp(OpKernelConstruction *context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + } + + void Compute(OpKernelContext *context) override { + auto handle = getHandle(context, 0); + const int64 batch_size = static_cast(handle->records.size()); + const auto &records = handle->records; + + try { + int64 total_size = 0; + for (const auto &record : records) { + try { + const auto &tensor = record.getRawTensor(feature_id); + total_size += static_cast(tensor.getNumElements()); + } catch(const std::out_of_range &err) { + LOG(WARNING) << "Ignoring missing string tensor with key: " << feature_id << std::endl; + continue; + } + } + + twml::ThriftReader reader(nullptr); + TensorShape shape = {total_size}; + Tensor *strings_tensor = nullptr; + Tensor *ids_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &ids_tensor)); + OP_REQUIRES_OK(context, context->allocate_output(1, shape, &strings_tensor)); + + auto strings_data = strings_tensor->flat().data(); + auto ids_data = ids_tensor->flat().data(); + + for (int64 i = 0; i < batch_size; i++) { + const auto &record = records[i]; + try { + const twml::RawTensor &tensor = record.getRawTensor(feature_id); + const uint8_t *buffer = static_cast(tensor.getData()); + const int64 num_strings = static_cast(tensor.getNumElements()); + reader.setBuffer(buffer); + + for (int64 j = 0; j < num_strings; j++) { + const uint8_t *curr_begin = nullptr; + const auto curr_length = reader.getRawBuffer(&curr_begin); + strings_data[j] = std::string(curr_begin, curr_begin + curr_length); + ids_data[j] = i; + } + ids_data += num_strings; + strings_data += num_strings; + } catch(const std::out_of_range &err) { + continue; + } + } + } catch(const std::exception &err) { + context->CtxFailureWithWarning(errors::InvalidArgument(err.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GetStringTensorsFromDataRecord") + .Device(DEVICE_CPU), + GetStringTensorsOp); + +REGISTER_KERNEL_BUILDER( + Name("GetStringTensorsFromHashedDataRecord") + .Device(DEVICE_CPU), + GetStringTensorsOp); + +REGISTER_OP("GetTensorsFromDataRecord") +.Attr("assert_shape: bool") +.Attr("feature_id: int") +.Input("data_record_handle: resource") +.Output("output: string") +.Output("out_shape: int64") +.Output("out_type: string") +.Output("out_endian: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns tensors from the data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + output: A 2D byte tensor representing the requested feature. 
+ out_shape: A tensor containing [batch_size, thrift_shape]. + out_type: Output type returned as a string tensor of size 1. + out_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. +)doc"); + +REGISTER_OP("GetTensorsFromHashedDataRecord") +.Attr("assert_shape: bool") +.Attr("feature_id: int") +.Input("hashed_data_record_handle: resource") +.Output("output: string") +.Output("out_shape: int64") +.Output("out_type: string") +.Output("out_endian: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that returns decodes and tensors from the hashed data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + output: A 2D byte tensor representing the requested feature. + out_shape: A tensor containing [batch_size, thrift_shape]. + out_type: Output type returned as a string tensor of size 1. + out_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. +)doc"); + +template +class GetTensorsOp : public OpKernel { + private: + bool assert_shape; + int64 feature_id; + + public: + explicit GetTensorsOp(OpKernelConstruction *context) + : OpKernel(context), assert_shape(true) { + OP_REQUIRES_OK(context, context->GetAttr("assert_shape", &assert_shape)); + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + } + + void Compute(OpKernelContext *context) override { + auto handle = getHandle(context, 0); + uint64 batch_size = handle->records.size(); + const auto &records = handle->records; + + try { + TensorShape raw_shape = {static_cast(batch_size)}; + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, raw_shape, &output_tensor)); + auto output_flat = output_tensor->flat(); + auto output_data = output_flat.data(); + + twml_type type = TWML_TYPE_UNKNOWN; + bool is_big_endian = false; + + std::vector shape(1, batch_size); + uint64 length = 0; + + for (auto record : records) { + const twml::RawTensor tensor = record.getRawTensor(feature_id); + const auto &curr_dims = tensor.getDims(); + const auto curr_type = tensor.getType(); + const bool curr_is_big_endian = tensor.is_big_endian(); + const uint64 curr_length = tensor.getRawLength(); + + // Create the output tensor based on first tensor + if (shape.size() == 1) { + // Push the shape of individual tensors into shape + shape.reserve(curr_dims.size() + 1); + shape.insert(shape.end(), curr_dims.begin(), curr_dims.end()); + type = curr_type; + is_big_endian = curr_is_big_endian; + length = curr_length; + + } else { + if (assert_shape) { + // Assert shape of all tensors is the same. + bool is_same_shape = std::equal(shape.begin() + 1, shape.end(), curr_dims.begin()); + + if (!is_same_shape || length != curr_length) { + throw std::runtime_error("TensorShape mismatch for feature_id: " + + std::to_string(feature_id)); + } + } + + // Assert type and endianness of all tensors is the same. + if (type != curr_type || is_big_endian != curr_is_big_endian) { + throw std::runtime_error("Tensor type mismatch for feature_id: " + + std::to_string(feature_id)); + } + } + + // Copy from datarecord to output + const uint8 *tensor_data = reinterpret_cast(tensor.getData()); + *output_data = std::string(tensor_data, tensor_data + curr_length); + + // Increment it for the next tensor in the batch. 
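+      // Each batch element is a single std::string of raw bytes; callers
+      // reinterpret it using the out_type and out_endian outputs.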
+ output_data++; + } + + Tensor *shape_tensor = nullptr; + TensorShape shape_shape = {static_cast(shape.size())}; + OP_REQUIRES_OK(context, context->allocate_output(1, shape_shape, &shape_tensor)); + auto shape_flat = shape_tensor->flat(); + for (int i = 0; i < static_cast(shape.size()); i++) { + shape_flat(i) = shape[i]; + } + + Tensor* type_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(2, {}, &type_tensor)); + type_tensor->scalar()() = twml::getTypeName(type); + + Tensor* endian_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, {}, &endian_tensor)); + endian_tensor->scalar()() = is_big_endian; + } catch(const std::exception &err) { + context->CtxFailureWithWarning(errors::InvalidArgument(err.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GetTensorsFromDataRecord") + .Device(DEVICE_CPU), + GetTensorsOp); + +REGISTER_KERNEL_BUILDER( + Name("GetTensorsFromHashedDataRecord") + .Device(DEVICE_CPU), + GetTensorsOp); + +REGISTER_OP("GetTensorsWithMissingMaskFromDataRecord") +.Attr("assert_shape: bool") +.Attr("feature_id: int") +.Attr("default_shape: list(int)") +.Attr("dtype_size: int") +.Input("data_record_handle: resource") +.Output("output: string") +.Output("out_type: string") +.Output("out_endian: uint8") +.Output("is_found: bool") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns tensors from the data record. + +Attr + assert_shape: Specifies if the shape needs to be same across the batch. + feature_id: The hashed id of the feature name. + default_shape: Expected shape of output tensor. + dtype_size: expected size of each element. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + output: A 2D byte tensor representing the requested feature. + out_type: A string tensor represnting the type. + out_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. + is_missing: A boolean tensor of length batch_size represnting if the tensor was found for an input. +)doc"); + +REGISTER_OP("GetTensorsWithMissingMaskFromHashedDataRecord") +.Attr("assert_shape: bool") +.Attr("feature_id: int") +.Attr("default_shape: list(int)") +.Attr("dtype_size: int") +.Input("hashed_data_record_handle: resource") +.Output("output: string") +.Output("out_type: string") +.Output("out_endian: uint8") +.Output("is_found: bool") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns tensors from the data record. + +Attr + assert_shape: Specifies if the shape needs to be same across the batch. + feature_id: The hashed id of the feature name. + default_shape: Expected shape of output tensor. + dtype_size: expected size of each element. + +Input + hashed_data_record_handle: Resource handle to HashedDataRecord. + +Outputs + output: A 2D byte tensor representing the requested feature. + out_type: A string tensor represnting the type. + out_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. + is_missing: A boolean tensor of length batch_size represnting if the tensor was found for an input. 
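+
+  When a tensor is missing for an input, the corresponding output element is a
+  zero-filled byte string of the default shape and is_found is set to false.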
+)doc"); + +template +class GetTensorsWithMissingMaskOp : public OpKernel { + private: + bool assert_shape; + int64 feature_id; + int64 dtype_size; + std::vector shape; + + public: + explicit GetTensorsWithMissingMaskOp(OpKernelConstruction *context) + : OpKernel(context), assert_shape(true) { + OP_REQUIRES_OK(context, context->GetAttr("assert_shape", &assert_shape)); + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + OP_REQUIRES_OK(context, context->GetAttr("default_shape", &shape)); + OP_REQUIRES_OK(context, context->GetAttr("dtype_size", &dtype_size)); + } + + void Compute(OpKernelContext *context) override { + auto handle = getHandle(context, 0); + uint64 batch_size = handle->records.size(); + const auto &records = handle->records; + + try { + TensorShape raw_shape = {static_cast(batch_size)}; + Tensor* output_tensor = nullptr; + Tensor* is_found_tensor = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, raw_shape, &output_tensor)); + OP_REQUIRES_OK(context, context->allocate_output(3, raw_shape, &is_found_tensor)); + + auto output_flat = output_tensor->flat(); + auto output_data = output_flat.data(); + auto is_found_data = is_found_tensor->flat().data(); + + twml_type type = TWML_TYPE_UNKNOWN; + bool is_big_endian = false; + + uint64 length = std::accumulate(shape.begin(), shape.end(), dtype_size, std::multiplies()); + for (auto record : records) { + try { + const twml::RawTensor tensor = record.getRawTensor(feature_id); + const auto &curr_dims = tensor.getDims(); + const auto curr_type = tensor.getType(); + const bool curr_is_big_endian = tensor.is_big_endian(); + const uint64 curr_length = tensor.getRawLength(); + + if (type == TWML_TYPE_UNKNOWN) { + type = curr_type; + is_big_endian = curr_is_big_endian; + // FloatTensors are stored as a list of doubles. + // If the requested dtype_size is 4, update the length. + // NOTE: All the missing tensors before this have wrong length, this is fixed at the end. + if (type == TWML_TYPE_DOUBLE && is_big_endian && dtype_size == 4) { + length = length * 2; + } + } else { + // Assert type and endianness of all tensors is the same. + if (type != curr_type || is_big_endian != curr_is_big_endian) { + throw std::runtime_error("Tensor type mismatch for feature_id: " + + std::to_string(feature_id)); + } + } + + // Assert shape of all tensors is the same. + if (assert_shape && type != TWML_TYPE_UNKNOWN) { + // Assert shape of all tensors is the same. + bool is_same_shape = std::equal(shape.begin(), shape.end(), curr_dims.begin()); + + if (!is_same_shape || length != curr_length) { + throw std::runtime_error("TensorShape mismatch for feature_id: " + + std::to_string(feature_id)); + } + } + + // Copy from datarecord to output + const uint8 *tensor_data = reinterpret_cast(tensor.getData()); + *output_data = std::string(tensor_data, tensor_data + curr_length); + *is_found_data = true; + } catch(const std::out_of_range &err) { + *output_data = std::string(); + output_data->resize(length); + *is_found_data = false; + } + + // Increment it for the next tensor in the batch. + output_data++; + is_found_data++; + } + + // Reset pointers to the beginning + output_data = output_flat.data(); + is_found_data = is_found_tensor->flat().data(); + + // Resize any missing tensors before type (and hence true length) was known. 
+ if (type == TWML_TYPE_DOUBLE) { + for (int64 i = 0; i < static_cast(records.size()); i++) { + if (!is_found_data[i]) { + output_data[i].resize(length); + } + } + } + + Tensor* type_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, {}, &type_tensor)); + type_tensor->scalar()() = twml::getTypeName(type); + + Tensor* endian_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(2, {}, &endian_tensor)); + endian_tensor->scalar()() = is_big_endian; + } catch(const std::exception &err) { + context->CtxFailureWithWarning(errors::InvalidArgument(err.what())); + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GetTensorsWithMissingMaskFromDataRecord") + .Device(DEVICE_CPU), + GetTensorsWithMissingMaskOp); + +REGISTER_KERNEL_BUILDER( + Name("GetTensorsWithMissingMaskFromHashedDataRecord") + .Device(DEVICE_CPU), + GetTensorsWithMissingMaskOp); + +REGISTER_OP("GetSparseTensorsFromDataRecord") +.Attr("feature_id: int") +.Input("data_record_handle: resource") +.Output("ids: int64") +.Output("indices: string") +.Output("values: string") +.Output("dense_shape: int64") +.Output("values_type: string") +.Output("valueendian: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns tensors from the data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + ids: A 1D tensor representing which input in the batch the value belongs to. + indices: An string tensor containing indices of the sparse tensor as bytes. + values: An string tensor containing values of the sparse tensor as bytes. + dense_shape: A tensor containing [batch_size, thrift_shape]. + values_type: The data type of value tensor returned as a string tensor of size 1. + values_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. +)doc"); + +REGISTER_OP("GetSparseTensorsFromHashedDataRecord") +.Attr("feature_id: int") +.Input("hashed_data_record_handle: resource") +.Output("ids: int64") +.Output("indices: string") +.Output("values: string") +.Output("dense_shape: int64") +.Output("values_type: string") +.Output("values_endian: uint8") +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + return Status::OK(); + }).Doc(R"doc( +A tensorflow OP that decodes and returns tensors from the data record. + +Attr + feature_id: The hashed id of the feature name. + +Input + data_record_handle: Resource handle to DataRecord. + +Outputs + ids: A 1D tensor representing which input in the batch the value belongs to. + indices: An string tensor containing indices of the sparse tensor as bytes. + values: An string tensor containing values of the sparse tensor as bytes. + dense_shape: A tensor containing [batch_size, thrift_shape]. + values_type: The data type of value tensor returned as a string tensor of size 1. + values_endian: Endianness of the bytes returned a tensor of size 1. 0: litte, 1: big. 
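+
+  Note that indices and values are returned as single-element string tensors
+  holding the concatenated raw bytes for the whole batch; ids maps each
+  element back to its input row.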
+)doc"); + +template +class GetSparseTensorsOp : public OpKernel { + private: + int64 feature_id; + + public: + explicit GetSparseTensorsOp(OpKernelConstruction *context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("feature_id", &feature_id)); + } + + void Compute(OpKernelContext *context) override { + auto handle = getHandle(context, 0); + const int64 batch_size = static_cast(handle->records.size()); + const auto &records = handle->records; + + try { + twml_type type = TWML_TYPE_UNKNOWN; + bool is_big_endian = false; + + std::vector shape(1, batch_size); + + int64 total_length = 0; + std::vector lengths; + lengths.reserve(batch_size); + + int64 total_indices_length = 0; + std::vector indices_raw_lengths; + std::vector indices_data_ptrs; + indices_raw_lengths.reserve(batch_size); + indices_data_ptrs.reserve(batch_size); + + int64 total_values_length = 0; + std::vector values_raw_lengths; + std::vector values_data_ptrs; + values_raw_lengths.reserve(batch_size); + values_data_ptrs.reserve(batch_size); + + for (auto record : records) { + const twml::RawSparseTensor sparse_tensor = record.getRawSparseTensor(feature_id); + const twml::RawTensor indices = sparse_tensor.indices(); + const twml::RawTensor values = sparse_tensor.values(); + const auto &dense_shape = sparse_tensor.denseShape(); + const auto indices_type = indices.getType(); + const auto indices_is_big_endian = indices.is_big_endian(); + const auto values_type = values.getType(); + const bool values_is_big_endian = values.is_big_endian(); + + const uint64 indices_length = indices.getDims().back(); + const uint64 values_length = values.getDims().back(); + + auto indices_raw_length = indices.getRawLength(); + auto values_raw_length = values.getRawLength(); + + auto indices_data_ptr = reinterpret_cast(indices.getData()); + auto values_data_ptr = reinterpret_cast(values.getData()); + + indices_raw_lengths.push_back(indices_raw_length); + values_raw_lengths.push_back(values_raw_length); + + indices_data_ptrs.push_back(indices_data_ptr); + values_data_ptrs.push_back(values_data_ptr); + + total_indices_length += indices_raw_length; + total_values_length += values_raw_length; + + if (shape.size() == 1) { + shape.reserve(dense_shape.size() + 1); + shape.insert(shape.end(), dense_shape.begin(), dense_shape.end()); + type = values_type; + is_big_endian = values_is_big_endian; + } + + // Assert shape of all tensors is the same. + if (!std::equal(shape.begin() + 1, shape.end(), dense_shape.begin())) { + throw std::runtime_error("dense_shape of sparse tensors doesn't match for feature_id: " + + std::to_string(feature_id)); + } + // Assert type of all values tensor is the same. + if (type != values_type || is_big_endian != values_is_big_endian) { + throw std::runtime_error("The type of values do not match for feature_id: " + + std::to_string(feature_id)); + } + // Assert indices tensor is big endian and of type INT64. 
+        if (indices_type != TWML_TYPE_INT64 || !indices_is_big_endian) {
+          throw std::runtime_error("Unexpected type for index tensor for feature_id: "
+                                   + std::to_string(feature_id));
+        }
+
+        if (indices_length != values_length) {
+          throw std::runtime_error("The length of values and indices does not match for feature_id: "
+                                   + std::to_string(feature_id));
+        }
+
+        lengths.push_back(indices_length);
+        total_length += indices_length;
+      }
+
+      Tensor* ids_tensor = nullptr;
+      TensorShape ids_shape = {static_cast<int64>(total_length)};
+      OP_REQUIRES_OK(context, context->allocate_output(0, ids_shape, &ids_tensor));
+      auto ids_tensor_flat = ids_tensor->flat<int64>();
+      auto ids_tensor_data = ids_tensor_flat.data();
+
+      TensorShape raw_shape = {static_cast<int64>(1)};
+
+      Tensor* indices_tensor = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(1, raw_shape, &indices_tensor));
+      auto indices_tensor_flat = indices_tensor->flat<string>();
+      auto indices_tensor_string = indices_tensor_flat.data();
+      indices_tensor_string->resize(total_indices_length);
+      auto indices_tensor_iter = indices_tensor_string->begin();
+
+      Tensor* values_tensor = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(2, raw_shape, &values_tensor));
+      auto values_tensor_flat = values_tensor->flat<string>();
+      auto values_tensor_string = values_tensor_flat.data();
+      values_tensor_string->resize(total_values_length);
+      auto values_tensor_iter = values_tensor_string->begin();
+
+      for (int64 i = 0; i < batch_size; i++) {
+        // Fill in the data for id == i for all values in the current input.
+        std::fill(ids_tensor_data, ids_tensor_data + lengths[i], i);
+        ids_tensor_data += lengths[i];
+
+        indices_tensor_iter = std::copy(indices_data_ptrs[i],
+                                        indices_data_ptrs[i] + indices_raw_lengths[i],
+                                        indices_tensor_iter);
+
+        values_tensor_iter = std::copy(values_data_ptrs[i],
+                                       values_data_ptrs[i] + values_raw_lengths[i],
+                                       values_tensor_iter);
+      }
+
+      Tensor *shape_tensor = nullptr;
+      TensorShape shape_shape = {static_cast<int64>(shape.size())};
+      OP_REQUIRES_OK(context, context->allocate_output(3, shape_shape, &shape_tensor));
+      auto shape_flat = shape_tensor->flat<int64>();
+      for (int i = 0; i < static_cast<int>(shape.size()); i++) {
+        shape_flat(i) = shape[i];
+      }
+
+      Tensor* type_tensor = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(4, {}, &type_tensor));
+      type_tensor->scalar<string>()() = twml::getTypeName(type);
+
+      Tensor* endian_tensor = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(5, {}, &endian_tensor));
+      endian_tensor->scalar<uint8>()() = is_big_endian;
+    } catch(const std::exception &err) {
+      context->CtxFailureWithWarning(errors::InvalidArgument(err.what()));
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+  Name("GetSparseTensorsFromDataRecord")
+  .Device(DEVICE_CPU),
+  GetSparseTensorsOp<DataRecordResource>);
+
+REGISTER_KERNEL_BUILDER(
+  Name("GetSparseTensorsFromHashedDataRecord")
+  .Device(DEVICE_CPU),
+  GetSparseTensorsOp<HashedDataRecordResource>);
diff --git a/twml/libtwml/src/ops/tensorflow_utils.cpp b/twml/libtwml/src/ops/tensorflow_utils.cpp
new file mode 100644
index 000000000..95ebc7e4c
--- /dev/null
+++ b/twml/libtwml/src/ops/tensorflow_utils.cpp
@@ -0,0 +1,87 @@
+#include "tensorflow_utils.h"
+#include <string>
+#include <vector>
+
+twml::Tensor TFTensor_to_twml_tensor(Tensor &input) {
+  int ndims = input.dims();
+  std::vector<uint64_t> dims(ndims);
+  std::vector<uint64_t> strides(ndims);
+  for (int i = 0; i < ndims; i++) {
+    dims[i] = input.dim_size(i);
+  }
+  uint64_t stride = 1;
+  // Row-major strides: the innermost dimension is contiguous.
+  for (int i = ndims - 1; i >= 0; i--) {
+    strides[i] = stride;
+    stride *= dims[i];
+  }
+
+  switch (input.dtype()) {
+    case DT_INT8:
+      return twml::Tensor(input.flat<int8>().data(), dims, strides, TWML_TYPE_INT8);
+    case DT_UINT8:
+      return twml::Tensor(input.flat<uint8>().data(), dims, strides, TWML_TYPE_UINT8);
+    case DT_INT32:
+      return twml::Tensor(input.flat<int32>().data(), dims, strides, TWML_TYPE_INT32);
+    case DT_INT64:
+      return twml::Tensor(input.flat<int64>().data(), dims, strides, TWML_TYPE_INT64);
+    case DT_FLOAT:
+      return twml::Tensor(input.flat<float>().data(), dims, strides, TWML_TYPE_FLOAT);
+    case DT_DOUBLE:
+      return twml::Tensor(input.flat<double>().data(), dims, strides, TWML_TYPE_DOUBLE);
+    case DT_BOOL:
+      return twml::Tensor(input.flat<bool>().data(), dims, strides, TWML_TYPE_BOOL);
+    case DT_STRING:
+      return twml::Tensor(input.flat<string>().data(), dims, strides, TWML_TYPE_STRING);
+    default:
+      throw twml::Error(TWML_ERR_TYPE, "Unknown tensor data type.");
+      break;
+  }
+}
+
+const twml::Tensor TFTensor_to_twml_tensor(const Tensor &input) {
+  // TODO: define some type of constant tensor, which should be used for inputs
+  // to enforce that they are not modified.
+  return TFTensor_to_twml_tensor(const_cast<Tensor &>(input));
+}
+
+twml::RawTensor TFTensor_to_twml_raw_tensor(Tensor &input) {
+  int ndims = input.dims();
+  std::vector<uint64_t> dims(ndims);
+  std::vector<uint64_t> strides(ndims);
+  for (int i = 0; i < ndims; i++) {
+    dims[i] = input.dim_size(i);
+  }
+  uint64_t stride = 1;
+  // Row-major strides: the innermost dimension is contiguous.
+  for (int i = ndims - 1; i >= 0; i--) {
+    strides[i] = stride;
+    stride *= dims[i];
+  }
+
+  switch (input.dtype()) {
+    case DT_INT8:
+      return twml::RawTensor(input.flat<int8>().data(), dims, strides, TWML_TYPE_INT8, false, input.flat<int8>().size());
+    case DT_UINT8:
+      return twml::RawTensor(input.flat<uint8>().data(), dims, strides, TWML_TYPE_UINT8, false, input.flat<uint8>().size());
+    case DT_INT32:
+      return twml::RawTensor(input.flat<int32>().data(), dims, strides, TWML_TYPE_INT32, false, input.flat<int32>().size());
+    case DT_INT64:
+      return twml::RawTensor(input.flat<int64>().data(), dims, strides, TWML_TYPE_INT64, false, input.flat<int64>().size());
+    case DT_FLOAT:
+      return twml::RawTensor(input.flat<float>().data(), dims, strides, TWML_TYPE_FLOAT, false, input.flat<float>().size());
+    case DT_DOUBLE:
+      return twml::RawTensor(input.flat<double>().data(), dims, strides, TWML_TYPE_DOUBLE, false, input.flat<double>().size());
+    case DT_BOOL:
+      return twml::RawTensor(input.flat<bool>().data(), dims, strides, TWML_TYPE_BOOL, false, input.flat<bool>().size());
+    case DT_STRING:
+      return twml::RawTensor(input.flat<string>().data(), dims, strides, TWML_TYPE_STRING, false, input.flat<string>().size());
+    default:
+      throw twml::Error(TWML_ERR_TYPE, "Unknown tensor data type.");
+      break;
+  }
+}
+
+const twml::RawTensor TFTensor_to_twml_raw_tensor(const Tensor &input) {
+  // TODO: define some type of constant tensor, which should be used for inputs
+  // to enforce that they are not modified.
+  return TFTensor_to_twml_raw_tensor(const_cast<Tensor &>(input));
+}
diff --git a/twml/libtwml/src/ops/tensorflow_utils.h b/twml/libtwml/src/ops/tensorflow_utils.h
new file mode 100644
index 000000000..4940f680d
--- /dev/null
+++ b/twml/libtwml/src/ops/tensorflow_utils.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include <twml.h>
+
+using namespace tensorflow;
+twml::Tensor TFTensor_to_twml_tensor(Tensor &input);
+twml::RawTensor TFTensor_to_twml_raw_tensor(Tensor &input);
+const twml::Tensor TFTensor_to_twml_tensor(const Tensor &input);
+const twml::RawTensor TFTensor_to_twml_raw_tensor(const Tensor &input);
+
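Both conversion helpers above compute dense row-major strides (innermost stride 1). The convention matches numpy's, except that numpy strides are in bytes; the Python Array wrapper later in this diff performs the same division by itemsize. A small illustration:

    import numpy as np

    a = np.zeros((4, 3, 2), dtype=np.float32)
    dims = list(a.shape)                                  # [4, 3, 2]
    strides = [s // a.dtype.itemsize for s in a.strides]  # [6, 2, 1]: row-major element strides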
diff --git a/twml/libtwml/src/ops/var_length_reader.cpp b/twml/libtwml/src/ops/var_length_reader.cpp
new file mode 100644
index 000000000..62b5fc2a1
--- /dev/null
+++ b/twml/libtwml/src/ops/var_length_reader.cpp
@@ -0,0 +1,46 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("VarLengthReader")
+.Input("input1: int32")
+.Output("output: int32")
+.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+    ::tensorflow::shape_inference::ShapeHandle input;
+    // Check that the input has exactly 1 dimension.
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
+    // There is no inference on the output shape: it depends on the input's contents.
+    return Status::OK();
+  });
+
+
+class VarLengthReaderOp : public OpKernel {
+ public:
+  explicit VarLengthReaderOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor.
+    const Tensor& input_tensor = context->input(0);
+    auto input = input_tensor.flat<int32>();
+
+    // Get the first element of the input tensor and use it as the output shape.
+    int32 len = input(0);
+    TensorShape output_shape = {1, len};
+
+    // Create an output tensor; its size is determined by the content of the input.
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
+
+    auto output_flat = output_tensor->flat<int32>();
+
+    // Fill the output with ones.
+    const int N = output_flat.size();
+    for (int i = 0; i < N; i++) {
+      output_flat(i) = 1;
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("VarLengthReader").Device(DEVICE_CPU), VarLengthReaderOp);
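Once compiled into the twml ops library, VarLengthReader can be exercised from Python roughly as follows. This is a sketch: the library file name and loading path are assumptions, and twml actually loads its custom ops through libtwml.

    import tensorflow.compat.v1 as tf

    ops_lib = tf.load_op_library('libtwml_tf.so')  # assumed library file name
    with tf.Session() as sess:
      out = ops_lib.var_length_reader([3])
      print(sess.run(out))  # [[1 1 1]]: a 1x3 tensor of ones, sized by the first input element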
diff --git a/twml/setup.cfg b/twml/setup.cfg
new file mode 100644
index 000000000..d887f33c2
--- /dev/null
+++ b/twml/setup.cfg
@@ -0,0 +1,8 @@
+[bdist_wheel]
+universal=1
+
+[build]
+build-lib=build_dir
+
+[bdist]
+bdist-base=build_dir
diff --git a/twml/setup.py b/twml/setup.py
new file mode 100644
index 000000000..7e4003bae
--- /dev/null
+++ b/twml/setup.py
@@ -0,0 +1,29 @@
+import os
+
+from setuptools import find_packages, setup
+
+
+THIS_DIR = os.path.dirname(os.path.realpath(__file__))
+TWML_TEST_DATA_DIR = os.path.join(THIS_DIR, 'twml/tests/data')
+
+data_files = []
+for parent, children, files in os.walk(TWML_TEST_DATA_DIR):
+  data_files += [os.path.join(parent, f) for f in files]
+
+setup(
+  name='twml',
+  version='2.0',
+  description="Tensorflow wrapper for twml",
+  packages=find_packages(exclude=["build"]),
+  install_requires=[
+    'thriftpy2',
+    'numpy',
+    'pyyaml',
+    'future',
+    'scikit-learn',
+    'scipy'
+  ],
+  package_data={
+    'twml': data_files,
+  },
+)
diff --git a/twml/twml/__init__.py b/twml/twml/__init__.py
new file mode 100644
index 000000000..0c96df68b
--- /dev/null
+++ b/twml/twml/__init__.py
@@ -0,0 +1,61 @@
+"""Importing the python op wrappers."""
+
+import os
+
+# Import from twitter.deepbird
+from twitter.deepbird.logging.log_level import set_logging_level  # noqa: F401
+from twitter.deepbird.sparse import SparseTensor  # noqa: F401
+from twitter.deepbird.sparse import sparse_dense_matmul  # noqa: F401
+
+from .util import dynamic_partition, feature_id, limit_bits, limit_sparse_tensor_size  # noqa: F401
+from .util import write_file, fixed_length_tensor, setup_tf_logging_formatter  # noqa: F401
+from .array import Array  # noqa: F401
+
+# Module to parse feature patterns and match them from data_spec.json
+from .feature_config import FeatureConfig, FeatureConfigBuilder  # noqa: F401
+
+# Data record streaming, reading, writing, and parsing.
+from .dataset import *  # noqa: T400
+from .readers import *  # noqa: T400
+from .block_format_writer import *  # noqa: T400
+
+# Graph output functions
+from .export_output_fns import *  # noqa: T400
+
+# Input parsers
+from .parsers import *  # noqa: T400
+
+# Input functions
+from .input_fns import *  # noqa: T400
+
+# Feature filter functions
+from .filters import *  # noqa: T400
+
+# Custom argparser for Trainer
+from .argument_parser import *  # noqa: T400
+
+from . import constants  # noqa: F401
+from . import errors  # noqa: F401
+from . import layers  # noqa: F401
+from . import lookup  # noqa: F401
+from . import readers  # noqa: F401
+from . import summary  # noqa: F401
+from . import tensorboard  # noqa: F401
+
+import tensorflow.compat.v1 as tf  # noqa: F402
+tf.disable_eager_execution()
+
+# TODO: Figure out a better way to deal with this.
+if 'OMP_NUM_THREADS' not in os.environ and 'MKL_NUM_THREADS' not in os.environ:
+  os.environ["OMP_NUM_THREADS"] = '1'
+
+# Import all custom C++ ops
+from libtwml import add1, partition_sparse_tensor, CLIB  # noqa: F401
+
+# Configure logging levels to info for various frameworks
+set_logging_level('INFO')
+
+from . import contrib  # noqa: F401
+from . import hooks  # noqa: F401
+from . import trainers  # noqa: F401
+from . import metrics  # noqa: F401
diff --git a/twml/twml/argument_parser.py b/twml/twml/argument_parser.py
new file mode 100644
index 000000000..c771eebdf
--- /dev/null
+++ b/twml/twml/argument_parser.py
@@ -0,0 +1,561 @@
+# pylint: disable=protected-access, arguments-differ
+"""
+Command-line argument parsing for the Trainer.
+"""
+import argparse
+from argparse import ArgumentError
+from operator import attrgetter
+import tempfile
+
+import twml
+import tensorflow.compat.v1 as tf
+
+
+SERIAL = "serial"
+TREE = "tree"
+LOG_LEVELS = {
+  "debug": tf.logging.DEBUG,
+  "info": tf.logging.INFO,
+  "warn": tf.logging.WARN,
+  "error": tf.logging.ERROR}
+
+
+class SortingHelpFormatter(argparse.HelpFormatter):
+  """
+  Used to sort args alphabetically in the help message.
+  """
+
+  def add_arguments(self, actions):
+    actions = sorted(actions, key=attrgetter('option_strings'))
+    super(SortingHelpFormatter, self).add_arguments(actions)
+
+
+def _set_log_level(level=None):
+  """Sets the tensorflow log level to the input level."""
+  if level is None:
+    return None
+  level = level.lower()
+  if level not in LOG_LEVELS.keys():
+    raise ValueError(f"Unexpected log level {level} was given but expected one of {LOG_LEVELS.keys()}.")
+  tf.logging.set_verbosity(LOG_LEVELS[level])
+  tf.logging.info(f"Setting tensorflow logging level to {level} or {LOG_LEVELS[level]}")
+  return level
+
+
+def get_trainer_parser():
+  """
+  Adds common command-line args to parse for the Trainer class.
+  Typically, the user calls this function and then parses cmd-line arguments
+  into an argparse.Namespace object which is then passed to the Trainer constructor
+  via the params argument.
+
+  See the `code <_modules/twml/argument_parser.html#get_trainer_parser>`_
+  for a list and description of all cmd-line arguments.
+
+  Returns:
+    argparse.ArgumentParser instance with some useful args already added.
+  """
+  parser = twml.DefaultSubcommandArgParse(formatter_class=SortingHelpFormatter)
+
+ "supports local filesystem path and hdfs://default/ which requires " + "setting HDFS configuration via env variable HADOOP_CONF_DIR ") + parser.add_argument( + "--export_dir", type=str, default=None, + help="Path to the directory to export a SavedModel for prediction servers.") + parser.add_argument( + "--log_aggregation_app_id", type=str, default=None, + help="specify app_id for log aggregation. disabled by default.") + parser.add_argument( + "--train.batch_size", "--train_batch_size", type=int, default=32, + dest='train_batch_size', + help="number of samples per training batch") + parser.add_argument( + "--eval.batch_size", "--eval_batch_size", type=int, default=32, + dest='eval_batch_size', + help="number of samples per cross-validation batch. Defaults to train_batch_size") + parser.add_argument( + "--train.learning_rate", "--learning_rate", type=float, default=0.002, + dest='learning_rate', + help="learning rate. Scales the gradient update.") + parser.add_argument( + "--train.steps", "--train_steps", type=int, default=-1, + dest='train_steps', + help="number of training batches before running evaluation." + "Defaults to -1 (runs through entire dataset). " + "Only used for Trainer.[train,learn]. " + "For Trainer.train_and_evaluate, use train.max_steps instead. ") + parser.add_argument( + "--eval.steps", "--eval_steps", type=int, default=-1, + dest="eval_steps", + help="number of steps per evaluation. Each batch is a step." + "Defaults to -1 (runs through entire dataset). ") + parser.add_argument( + "--eval.period", "--eval_period", type=int, default=600, + dest="eval_period", + help="Trainer.train_and_evaluate waits for this long after each evaluation. " + "Defaults to 600 seconds (evaluate every ten minutes). " + "Note that anything lower than 10*60seconds is probably a bad idea because TF saves " + "checkpoints every 10mins by default. eval.delay is time to wait before doing first eval. " + "eval.period is time between successive evals.") + parser.add_argument( + "--eval.delay", "--eval_delay", type=int, default=120, + dest="eval_delay", + help="Trainer.train_and_evaluate waits for this long before performing the first evaluation" + "Defaults to 120 seconds (evaluate after first 2 minutes of training). " + "eval.delay is time to wait before doing first eval. " + "eval.period is time between successive evals.") + parser.add_argument( + "--train.max_steps", "--train_max_steps", type=int, default=None, + dest="train_max_steps", + help="Stop training after this many global steps. Each training batch is its own step." + "If set to None, step after one train()/evaluate() call. Useful when train.steps=-1." + "If set to a non-positive value, loop forever. Usually useful with early stopping.") + parser.add_argument( + "--train.log_metrics", dest="train_log_metrics", action="store_true", default=False, + help="Set this to true to see metrics during training. " + "WARNING: metrics during training does not represent model performance. " + "WARNING: use for debugging only as this slows down training.") + parser.add_argument( + "--train.early_stop_patience", "--early_stop_patience", type=int, default=-1, + dest="early_stop_patience", + help="max number of evaluations (epochs) to wait for an improvement in the early_stop_metric." + "Defaults to -1 (no early-stopping)." 
+ "NOTE: This can not be enabled when --distributed is also set.") + parser.add_argument( + "--train.early_stop_tolerance", "--early_stop_tolerance", type=float, default=0, + dest="early_stop_tolerance", + help="a non-negative tolerance for comparing early_stop_metric." + "e.g. when maximizing the condition is current_metric > best_metric + tolerance." + "Defaults to 0.") + parser.add_argument( + "--train.dataset_shards", "--train_dataset_shards", + dest="train_dataset_shards", + type=int, default=None, + help="An int value that indicates the number of partitions (shards) for the dataset. This is" + " useful for codistillation and other techniques that require each worker to train on disjoint" + " partitions of the dataset.") + parser.add_argument( + "--train.dataset_shard_index", "--train_dataset_shard_index", + dest="train_dataset_shard_index", + type=int, default=None, + help="An int value (starting at zero) that indicates which partition (shard) of the dataset" + " to use if --train.dataset_shards is set.") + parser.add_argument( + "--continue_from_checkpoint", dest="continue_from_checkpoint", action="store_true", + help="DEPRECATED. This option is currently a no-op." + " Continuing from the provided checkpoint is now the default." + " Use --overwrite_save_dir if you would like to override it instead" + " and restart training from scratch.") + parser.add_argument( + "--overwrite_save_dir", dest="overwrite_save_dir", action="store_true", + help="Delete the contents of the current save_dir if it exists") + parser.add_argument( + "--data_threads", "--num_threads", type=int, default=2, + dest="num_threads", + help="Number of threads to use for loading the dataset. " + "num_threads is deprecated and to be removed in future versions. Use data_threads.") + parser.add_argument( + "--max_duration", "--max_duration", type=float, default=None, + dest="max_duration", + help="Maximum duration (in secs) that training/validation will be allowed to run for before being automatically terminated.") + parser.add_argument( + "--num_workers", type=int, default=None, + help="Number of workers to use when training in hogwild manner on a single node.") + parser.add_argument( + "--distributed", dest="distributed", action="store_true", + help="Pass this flag to use train_and_evaluate to train in a distributed fashion" + "NOTE: You can not use early stopping when --distributed is enabled" + ) + parser.add_argument( + "--distributed_training_cleanup", + dest="distributed_training_cleanup", + action="store_true", + help="Set if using distributed training on GKE to stop TwitterSetDeployment" + "from continuing training upon restarts (will be deprecated once we migrate off" + "TwitterSetDeployment for distributed training on GKE)." + ) + parser.add_argument( + "--disable_auto_ps_shutdown", default=False, action="store_true", + help="Disable the functionality of automatically shutting down parameter server after " + "distributed training complete (either succeed or failed)." + ) + parser.add_argument( + "--disable_tensorboard", default=False, action="store_true", + help="Do not start the TensorBoard server." + ) + parser.add_argument( + "--tensorboard_port", type=int, default=None, + help="Port for tensorboard to run on. Ignored if --disable_tensorboard is set.") + parser.add_argument( + "--health_port", type=int, default=None, + help="Port to listen on for health-related endpoints (e.g. graceful shutdown)." + "Not user-facing as it is set automatically by the twml_cli." 
+  parser.add_argument(
+    "--stats_port", type=int, default=None,
+    help="Port to listen on for stats endpoints"
+  )
+  parser.add_argument(
+    "--experiment_tracking_path",
+    dest="experiment_tracking_path",
+    type=str, default=None,
+    help="The tracking path of this experiment. Format: "
+    "user_name:project_name:experiment_name:run_name. The path is used to track and display "
+    "a record of this experiment on ML Dashboard. Note: this embedded experiment tracking is "
+    "disabled when the deprecated Model Repo TrackRun is used in your model config.")
+  parser.add_argument(
+    "--disable_experiment_tracking",
+    dest="disable_experiment_tracking",
+    action="store_true",
+    help="Whether experiment tracking should be disabled.")
+  parser.add_argument(
+    "--config.save_checkpoints_secs", "--save_checkpoints_secs", type=int, default=600,
+    dest='save_checkpoints_secs',
+    help="Configures the tf.estimator.RunConfig.save_checkpoints_secs attribute. "
+    "Specifies how often checkpoints are saved in seconds. Defaults to 10*60 seconds.")
+  parser.add_argument(
+    "--config.keep_checkpoint_max", "--keep_checkpoint_max", type=int, default=20,
+    dest='keep_checkpoint_max',
+    help="Configures the tf.estimator.RunConfig.keep_checkpoint_max attribute. "
+    "Specifies how many checkpoints to keep. Defaults to 20.")
+  parser.add_argument(
+    "--config.tf_random_seed", "--tf_random_seed", type=int, default=None,
+    dest='tf_random_seed',
+    help="Configures the tf.estimator.RunConfig.tf_random_seed attribute. "
+    "Specifies the seed to use. Defaults to None.")
+  parser.add_argument(
+    "--optimizer", type=str, default='SGD',
+    help="Optimizer to use: SGD (default), Adagrad, Adam, Ftrl, Momentum, RMSProp, LazyAdam, DGC.")
+  parser.add_argument(
+    "--gradient_noise_scale", type=float, default=None,
+    help="adds 0-mean normal noise scaled by this value. Defaults to None.")
+  parser.add_argument(
+    "--clip_gradients", type=float, default=None,
+    help="If specified, global clipping is applied to prevent "
+    "the norm of the gradient from exceeding this value. Defaults to None.")
+  parser.add_argument(
+    "--dgc.density", "--dgc_density", type=float, default=0.1,
+    dest="dgc_density",
+    help="Specifies the gradient density level when using the deep gradient compression optimizer. "
+    "E.g., the default value of 0.1 means that only the top 10%% most significant rows "
+    "(based on absolute value sums) are kept."
+  )
+  parser.add_argument(
+    "--dgc.density_decay", "--dgc_density_decay", type=bool, default=True,
+    dest="dgc_density_decay",
+    help="Specifies whether to (exponentially) decay the gradient density level when "
+    "doing gradient compression. If set to 'False', the 'density_decay_steps', "
+    "'density_decay_rate' and 'min_density' arguments will be ignored."
+  )
+  parser.add_argument(
+    "--dgc.density_decay_steps", "--dgc_density_decay_steps", type=int, default=10000,
+    dest="dgc_density_decay_steps",
+    help="Specifies the step interval to perform density decay."
+  )
+  parser.add_argument(
+    "--dgc.density_decay_rate", "--dgc_density_decay_rate", type=float, default=0.5,
+    dest="dgc_density_decay_rate",
+    help="Specifies the decay rate when performing density decay."
+  )
+  parser.add_argument(
+    "--dgc.min_density", "--dgc_min_density", type=float, default=0.1,
+    dest="dgc_min_density",
+    help="Specifies the minimum density level when performing density decay."
+  )
+  parser.add_argument(
+    "--dgc.accumulation", "--dgc_accumulation", type=bool, default=False,
+    dest="dgc_accumulation",
+    help="Specifies whether to accumulate small gradients when using the deep gradient "
+    "compression optimizer."
+  )
+  parser.add_argument(
+    "--show_optimizer_summaries", dest="show_optimizer_summaries", action="store_true",
+    help="When specified, displays gradients and learning rate in tensorboard. "
+    "Turning it on has a 10-20%% performance hit. Enable for debugging only.")
+
+  parser.add_argument(
+    "--num_mkl_threads", dest="num_mkl_threads", default=1, type=int,
+    help="Specifies how many threads to use for MKL. "
+    "inter_op_parallelism_threads is set to TWML_NUM_CPUS / num_mkl_threads. "
+    "intra_op_parallelism_threads is set to num_mkl_threads.")
+
+  parser.add_argument("--verbosity", type=_set_log_level, choices=LOG_LEVELS.keys(), default=None,
+                      help="Sets the log level to a given verbosity.")
+
+  parser.add_argument(
+    "--feature_importance.algorithm", dest="feature_importance_algorithm",
+    type=str, default=TREE, choices=[SERIAL, TREE],
+    help="""
+    There are two algorithms that the module supports, `serial` and `tree`.
+    The `serial` algorithm computes feature importances for each feature, and
+    the `tree` algorithm groups features by feature name prefix, computes feature
+    importances for groups of features, and then only 'zooms in' on a group when its
+    importance is greater than the `--feature_importance.sensitivity` value. The `tree` algorithm
+    will usually run faster, but for relatively unimportant features it will only compute an
+    upper bound rather than an exact importance value. We suggest that users generally stick
+    to the `tree` algorithm, unless they have a very small number of features or
+    near-random model performance.
+    """)
+
+  parser.add_argument(
+    "--feature_importance.sensitivity", dest="feature_importance_sensitivity", type=float, default=0.03,
+    help="""
+    The maximum amount that permuting a feature group can cause the model performance (determined
+    by `feature_importance.metric`) to drop before the algorithm decides not to expand the feature
+    group. This is only used for the `tree` algorithm.
+    """)
+
+  parser.add_argument(
+    "--feature_importance.dont_build_tree", dest="dont_build_tree", action="store_true", default=False,
+    help="""
+    If True, don't build the feature trie for the tree algorithm and only use the extra_groups
+    """)
+
+  parser.add_argument(
+    "--feature_importance.split_feature_group_on_period", dest="split_feature_group_on_period", action="store_true", default=False,
+    help="If true, split feature groups on the period rather than the optimal prefix. Only used for the TREE algorithm.")
+
+  parser.add_argument(
+    "--feature_importance.example_count", dest="feature_importance_example_count", type=int, default=10000,
+    help="""
+    The number of examples used to compute feature importance.
+    Larger values yield more reliable results, but also take longer to compute.
+    These records are loaded into memory. This number is agnostic to batch size.
+    """)
+
+  parser.add_argument(
+    "--feature_importance.data_dir", dest="feature_importance_data_dir", type=str, default=None,
+    help="Path to the dataset used to compute feature importance. "
+    "Supports local filesystem paths and hdfs://default/, which requires "
+    "setting HDFS configuration via the env variable HADOOP_CONF_DIR. "
+    "Defaults to eval_data_dir.")
+ "supports local filesystem path and hdfs://default/ which requires " + "setting HDFS configuration via env variable HADOOP_CONF_DIR " + "Defaults to eval_data_dir") + + parser.add_argument( + "--feature_importance.metric", dest="feature_importance_metric", type=str, default="roc_auc", + help="The metric used to determine when to stop expanding the feature importance tree. This is only used for the `tree` algorithm.") + + parser.add_argument( + "--feature_importance.is_metric_larger_the_better", dest="feature_importance_is_metric_larger_the_better", action="store_true", default=False, + help="If true, interpret `--feature_importance.metric` to be a metric where larger values are better (e.g. ROC_AUC)") + + parser.add_argument( + "--feature_importance.is_metric_smaller_the_better", dest="feature_importance_is_metric_smaller_the_better", action="store_true", default=False, + help="If true, interpret `--feature_importance.metric` to be a metric where smaller values are better (e.g. LOSS)") + + subparsers = parser.add_subparsers(help='Learning Rate Decay Functions. Can only pass 1.' + 'Should be specified after all the optional arguments' + 'and followed by its specific args' + 'e.g. --learning_rate 0.01 inverse_learning_rate_decay_fn' + ' --decay_rate 0.0004 --min_learning_rate 0.001', + dest='learning_rate_decay') + + # Create the parser for the "exponential_learning_rate_decay_fn" + parser_exponential = subparsers.add_parser('exponential_learning_rate_decay', + help='Exponential learning rate decay. ' + 'Exponential decay implements:' + 'decayed_learning_rate = learning_rate * ' + 'exponential_decay_rate ^ ' + '(global_step / decay_steps') + parser_exponential.add_argument( + "--decay_steps", type=float, default=None, + help="Required for 'exponential' learning_rate_decay.") + parser_exponential.add_argument( + "--exponential_decay_rate", type=float, default=None, + help="Required for 'exponential' learning_rate_decay. Must be positive. ") + + # Create the parser for the "polynomial_learning_rate_decay_fn" + parser_polynomial = subparsers.add_parser('polynomial_learning_rate_decay', + help='Polynomial learning rate decay. ' + 'Polynomial decay implements: ' + 'global_step = min(global_step, decay_steps)' + 'decayed_learning_rate = ' + '(learning_rate - end_learning_rate) * ' + '(1 - global_step / decay_steps) ^ ' + '(polynomial_power) + end_learning_rate' + 'So for linear decay you can use a ' + 'polynomial_power=1 (the default)') + parser_polynomial.add_argument( + "--end_learning_rate", type=float, default=0.0001, + help="Required for 'polynomial' learning_rate_decay (ignored otherwise).") + parser_polynomial.add_argument( + "--polynomial_power", type=float, default=0.0001, + help="Required for 'polynomial' learning_rate_decay." + "The power of the polynomial. Defaults to linear, 1.0.") + parser_polynomial.add_argument( + "--decay_steps", type=float, default=None, + help="Required for 'polynomial' learning_rate_decay. ") + + # Create the parser for the "piecewise_constant_learning_rate_decay_fn" + parser_piecewise_constant = subparsers.add_parser('piecewise_constant_learning_rate_decay', + help='Piecewise Constant ' + 'learning rate decay. ' + 'For piecewise_constant, ' + 'consider this example: ' + 'We want to use a learning rate ' + 'that is 1.0 for' + 'the first 100000 steps,' + '0.5 for steps 100001 to 110000, ' + 'and 0.1 for any additional steps. 
+  parser_piecewise_constant.add_argument(
+    "--piecewise_constant_values",
+    action=parse_comma_separated_list(element_type=float),
+    default=None,
+    help="Required for 'piecewise_constant_values' learning_rate_decay. "
+    "A list of comma separated floats or ints that specifies the values "
+    "for the intervals defined by boundaries. It should have one more "
+    "element than boundaries.")
+  parser_piecewise_constant.add_argument(
+    "--piecewise_constant_boundaries",
+    action=parse_comma_separated_list(element_type=int),
+    default=None,
+    help="Required for 'piecewise_constant_values' learning_rate_decay. "
+    "A list of comma separated integers, with strictly increasing entries.")
+
+  # Create the parser for the "inverse_learning_rate_decay_fn"
+  parser_inverse = subparsers.add_parser('inverse_learning_rate_decay',
+                                         help='Inverse learning rate decay. '
+                                         'Inverse implements: '
+                                         'decayed_lr = max(lr / (1 + decay_rate * '
+                                         'floor(global_step / decay_steps)), '
+                                         'min_learning_rate). '
+                                         'When decay_steps=1 this mimics the behaviour '
+                                         'of the default learning rate decay '
+                                         'of DeepBird v1.')
+
+  parser_inverse.add_argument(
+    "--decay_rate", type=float, default=None,
+    help="Required for 'inverse' learning_rate_decay. Rate at which we decay the learning rate.")
+  parser_inverse.add_argument(
+    "--min_learning_rate", type=float, default=None,
+    help="Required for 'inverse' learning_rate_decay. Minimum possible learning_rate.")
+  parser_inverse.add_argument(
+    "--decay_steps", type=float, default=1,
+    help="Required for 'inverse' learning_rate_decay.")
+
+  # Create the parser for the "cosine_learning_rate_decay_fn"
+  parser_cosine = subparsers.add_parser('cosine_learning_rate_decay',
+                                        help='Cosine learning rate decay. '
+                                        'Cosine implements: '
+                                        'decayed_lr = 0.5 * (1 + cos(pi * '
+                                        'global_step / decay_steps)) * lr')
+
+  parser_cosine.add_argument(
+    "--alpha", type=float, default=0,
+    help="A scalar float32 or float64 Tensor or a Python number. "
+    "Minimum learning rate value as a fraction of learning_rate.")
+  parser_cosine.add_argument(
+    "--decay_steps", type=float,
+    help="Required for 'cosine' learning_rate_decay.")
+
+  # Create the parser for the "cosine_restarts_learning_rate_decay_fn"
+  parser_cosine_restart = subparsers.add_parser('cosine_restarts_learning_rate_decay',
+                                                help='Applies cosine decay with restarts '
+                                                'to the learning rate. '
+                                                'See [Loshchilov & Hutter, ICLR2016], '
+                                                'SGDR: Stochastic '
+                                                'Gradient Descent with Warm Restarts. '
+                                                'https://arxiv.org/abs/1608.03983')
+  parser_cosine_restart.add_argument(
+    "--first_decay_steps", type=float,
+    help="Required for 'cosine_restarts' learning_rate_decay.")
+  parser_cosine_restart.add_argument(
+    "--alpha", type=float, default=0,
+    help="A scalar float32 or float64 Tensor or a Python number. "
+    "Minimum learning rate value as a fraction of learning_rate.")
+  parser_cosine_restart.add_argument(
+    "--t_mul", type=float, default=2,
+    help="A scalar float32 or float64 Tensor or a Python number. "
+    "Used to derive the number of iterations in the i-th period.")
+  parser_cosine_restart.add_argument(
+    "--m_mul", type=float, default=1,
+    help="A scalar float32 or float64 Tensor or a Python number. "
+    "Used to derive the initial learning rate of the i-th period.")
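For reference, the decay subcommands compose with the regular flags like this (a sketch using the parsers defined above):

    parser = get_trainer_parser()
    args = parser.parse_args([
      "--train.learning_rate", "0.01",
      "inverse_learning_rate_decay",
      "--decay_rate", "0.0004",
      "--min_learning_rate", "0.001",
    ])
    # args.learning_rate == 0.01
    # args.learning_rate_decay == "inverse_learning_rate_decay"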
+
+  # Create a dummy parser for None, which is the default.
+  parser_default = subparsers.add_parser(
+    'no_learning_rate_decay',
+    help='No learning rate decay')  # noqa: F841
+
+  parser.set_default_subparser('no_learning_rate_decay')
+
+  return parser
+
+
+class DefaultSubcommandArgParse(argparse.ArgumentParser):
+  """
+  Subclass of argparse.ArgumentParser that sets a default subparser.
+  """
+  _DEFAULT_SUBPARSER = None
+
+  def set_default_subparser(self, name):
+    """
+    Sets the default subparser.
+    """
+    self._DEFAULT_SUBPARSER = name
+
+  def _parse_known_args(self, arg_strings, *args, **kwargs):
+    """
+    Overwrites _parse_known_args.
+    """
+    in_args = set(arg_strings)
+    d_sp = self._DEFAULT_SUBPARSER
+    if d_sp is not None and not {'-h', '--help'}.intersection(in_args):
+      for x_val in self._subparsers._actions:
+        subparser_found = (
+          isinstance(x_val, argparse._SubParsersAction) and
+          in_args.intersection(x_val._name_parser_map.keys())
+        )
+        if subparser_found:
+          break
+      else:
+        # Append the default subcommand at the end; this implies no
+        # global options after the sub_parser is specified.
+        arg_strings = arg_strings + [d_sp]
+    return super(DefaultSubcommandArgParse, self)._parse_known_args(
+      arg_strings, *args, **kwargs
+    )
+
+  def _check_value(self, action, value):
+    try:
+      super(DefaultSubcommandArgParse, self)._check_value(
+        action, value
+      )
+    except ArgumentError as error:
+      error.message += ("\nERROR: Deepbird is trying to interpret \"{}\" as a value of {}. If this is not what you expected, "
+                        "then most likely one of the following two things is happening: either one of your cli arguments is not recognized, "
+                        "probably {} or whichever argument you are passing {} as a value to, OR you are passing in an argument after "
+                        "the `learning_rate_decay` argument.\n").format(value, action.dest, value, value)
+      raise error
+
+
+def parse_comma_separated_list(element_type=str):
+  """
+  Generates an argparse.Action that converts a string representing a comma separated list to a
+  list and converts each element to a specified type.
+  """
+
+  # pylint: disable-msg=too-few-public-methods
+  class _ParseCommaSeparatedList(argparse.Action):
+    """
+    Converts a string representing a comma separated list to a list and converts each element to a
+    specified type.
+    """
+
+    def __call__(self, parser, namespace, values, option_string=None):
+      if values is not None:
+        values = [element_type(v) for v in values.split(',')]
+      setattr(namespace, self.dest, values)
+
+  return _ParseCommaSeparatedList
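A quick sketch of how the generated action behaves, mirroring the piecewise-constant flags above:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--boundaries", action=parse_comma_separated_list(element_type=int))
    ns = p.parse_args(["--boundaries", "100000,110000"])
    assert ns.boundaries == [100000, 110000]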
diff --git a/twml/twml/array.py b/twml/twml/array.py
new file mode 100644
index 000000000..a8524a06d
--- /dev/null
+++ b/twml/twml/array.py
@@ -0,0 +1,101 @@
+"""Module containing a wrapper class to allow numpy arrays to work with twml functions"""
+
+import ctypes as ct
+
+from absl import logging
+from libtwml import CLIB
+import numpy as np
+
+
+_NP_TO_TWML_TYPE = {
+  'float32': ct.c_int(1),
+  'float64': ct.c_int(2),
+  'int32': ct.c_int(3),
+  'int64': ct.c_int(4),
+  'int8': ct.c_int(5),
+  'uint8': ct.c_int(6),
+}
+
+
+class Array(object):
+  """
+  Wrapper class to allow numpy arrays to work with twml functions.
+  """
+
+  def __init__(self, array):
+    """
+    Wraps a numpy array and creates a handle that can be passed to C functions from libtwml.
+
+    array: Numpy array
+    """
+    if not isinstance(array, np.ndarray):
+      raise TypeError("Input must be a numpy array")
+
+    try:
+      ttype = _NP_TO_TWML_TYPE[array.dtype.name]
+    except KeyError as err:
+      logging.error("Unsupported numpy type")
+      raise err
+
+    handle = ct.c_void_p(0)
+    ndim = ct.c_int(array.ndim)
+    dims = array.ctypes.get_shape()
+    isize = array.dtype.itemsize
+
+    # twml strides are in elements; numpy strides are in bytes.
+    strides_t = ct.c_size_t * array.ndim
+    strides = strides_t(*[n // isize for n in array.strides])
+
+    err = CLIB.twml_tensor_create(ct.pointer(handle),
+                                  array.ctypes.get_as_parameter(),
+                                  ndim, dims, strides, ttype)
+
+    # 1000 means TWML_ERR_NONE
+    if err != 1000:
+      raise RuntimeError("Error from libtwml")
+
+    # Store the numpy array to ensure it isn't deleted before self
+    self._array = array
+
+    self._handle = handle
+
+    self._type = ttype
+
+  @property
+  def handle(self):
+    """
+    Return the twml handle
+    """
+    return self._handle
+
+  @property
+  def shape(self):
+    """
+    Return the shape
+    """
+    return self._array.shape
+
+  @property
+  def ndim(self):
+    """
+    Return the number of dimensions
+    """
+    return self._array.ndim
+
+  @property
+  def array(self):
+    """
+    Return the numpy array
+    """
+    return self._array
+
+  @property
+  def dtype(self):
+    """
+    Return the numpy dtype
+    """
+    return self._array.dtype
+
+  def __del__(self):
+    """
+    Delete the handle
+    """
+    CLIB.twml_tensor_delete(self._handle)
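Typical usage of the wrapper (a sketch):

    import numpy as np

    arr = Array(np.zeros((2, 3), dtype='float32'))
    print(arr.shape, arr.ndim, arr.dtype)  # (2, 3) 2 float32
    # arr.handle can be passed to libtwml C functions; the wrapper keeps the
    # underlying numpy array alive for the lifetime of the handle.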
+ """ + err = CLIB.block_format_flush(self._handle) + if err != 1000: + raise RuntimeError("Error from libtwml") + + def __del__(self): + """ + Delete the handle + """ + if self._handle: + CLIB.block_format_writer_delete(self._handle) diff --git a/twml/twml/constants.py b/twml/twml/constants.py new file mode 100644 index 000000000..c6c726eed --- /dev/null +++ b/twml/twml/constants.py @@ -0,0 +1,11 @@ +# These should coincide with 'enum class DecodeMode' values in HashedDataRecordReader.h + +from twitter.deepbird.io.legacy.constants import ( + DECODE_MODES, # noqa: F401 + DEFAULT_DECODE_MODE, # noqa: F401 + HASH_FNAME_AND_VALNAME, # noqa: F401 + HASH_VALNAME, # noqa: F401 + HashingDiscretizerOptions, # noqa: F401 + DEFAULT_ZOOKEEPER_BASE_ZNODE, # noqa: F401 + DEFAULT_ZOOKEEPER_HOST, # noqa: F401 +) diff --git a/twml/twml/contrib/__init__.py b/twml/twml/contrib/__init__.py new file mode 100644 index 000000000..1a5e8efe4 --- /dev/null +++ b/twml/twml/contrib/__init__.py @@ -0,0 +1,21 @@ +# pylint: disable=wildcard-import +""" experimental and contributed modules """ + +from . import layers # noqa: F401 +from . import feature_importances # noqa: F401 +from . import calibrators # noqa: F401 +from . import readers # noqa: F401 +from . import utils # noqa: F401 +from . import build_graphs_fns # noqa: F401 +from . import feature_config # noqa: F401 +from . import parsers # noqa: F401 +from . import initializers # noqa: F401 +from . import export # noqa: F401 +from . import feature_config_parsers # noqa: F401 + +# These imports do not work with TF 2.x and are not needed either. +# If you are using TF 2.x, use the modular targets under src/python/twitter/deepbird. +import tensorflow +from . import trainers # noqa: F401 +from . import metrics # noqa: F401 +from . import hooks # noqa: F401 diff --git a/twml/twml/contrib/build_graphs_fns.py b/twml/twml/contrib/build_graphs_fns.py new file mode 100644 index 000000000..829f61512 --- /dev/null +++ b/twml/twml/contrib/build_graphs_fns.py @@ -0,0 +1,32 @@ +# pylint: disable=unused-argument, missing-docstring +''' +Common build graphs that can be reused +''' +import tensorflow.compat.v1 as tf + + +def get_saved_modules_graph(input_graph_fn): + """ + Get common graph for stitching different saved modules for export. + This graph is used to save checkpoints; and then export the modules + as a unity. + Args: + features: + model features + params: + model params + input_graph_fn: + main logic for the stitching + Returns: + build_graph + """ + def build_graph(features, label, mode, params, config=None): + output = input_graph_fn(features, params) + # If mode is train, we just need to assign a dummy loss + # and update the train op. This is done to save the graph to save_dir. + if mode == 'train': + loss = tf.constant(1) + train_op = tf.assign_add(tf.train.get_global_step(), 1) + return {'train_op': train_op, 'loss': loss} + return output + return build_graph diff --git a/twml/twml/contrib/calibrators/__init__.py b/twml/twml/contrib/calibrators/__init__.py new file mode 100644 index 000000000..02181ed12 --- /dev/null +++ b/twml/twml/contrib/calibrators/__init__.py @@ -0,0 +1,18 @@ +# pylint: disable=wildcard-import +""" +This module contains classes used for calibration. +Typically, each calibrator defines a ``twml.calibrator.Calibrator`` subclass +and a ``twml.calibrator.CalibrationFeature``. +The latter manages weights and values of individual features. 
+The former manages a set of ``CalibrationFeature`` objects
+(although some ``Calibrators`` don't use ``CalibrationFeature``).
+Ultimately, the ``Calibrator`` should produce an initialized layer via its ``to_layer()`` method.
+"""
+
+from .common_calibrators import calibrate_discretizer_and_export, add_discretizer_arguments  # noqa: F401
+from .calibrator import Calibrator  # noqa: F401
+from .mdl import MDLCalibrator  # noqa: F401
+from .isotonic import IsotonicCalibrator  # noqa: F401
+from .percentile_discretizer import PercentileDiscretizerCalibrator  # noqa: F401
+from .hashed_percentile_discretizer import HashedPercentileDiscretizerCalibrator  # noqa: F401
+from .hashing_discretizer import HashingDiscretizerCalibrator  # noqa: F401
\ No newline at end of file
diff --git a/twml/twml/contrib/calibrators/calibrator.py b/twml/twml/contrib/calibrators/calibrator.py
new file mode 100644
index 000000000..7408412e0
--- /dev/null
+++ b/twml/twml/contrib/calibrators/calibrator.py
@@ -0,0 +1,157 @@
+# pylint: disable=missing-docstring, unused-argument
+''' Contains the base classes for CalibrationFeature and Calibrator '''
+
+
+from collections import defaultdict
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import tensorflow_hub as hub
+import twml
+import twml.util
+
+
+class CalibrationFeature(object):
+  '''
+  Accumulates values and weights for individual features.
+  Typically, each unique feature defined in the accumulated SparseTensor or Tensor
+  would have its own CalibrationFeature instance.
+  '''
+
+  def __init__(self, feature_id):
+    ''' Constructs a CalibrationFeature.
+
+    Arguments:
+      feature_id:
+        number identifying the feature.
+    '''
+    self.feature_id = feature_id
+    self._calibrated = False
+    self._features_dict = defaultdict(list)
+
+  def add_values(self, new_features):
+    '''
+    Extends the lists to contain the values in this batch.
+    '''
+    for key in new_features:
+      self._features_dict[key].append(new_features[key])
+
+  def _concat_arrays(self):
+    '''
+    This class calls this function after you have added all the values.
+    It creates a dictionary with the concatenated arrays.
+    '''
+    self._features_dict.update((k, np.concatenate(v)) for k, v in self._features_dict.items())
+
+  def calibrate(self, *args, **kwargs):
+    raise NotImplementedError
+
+
+class Calibrator(object):
+  '''
+  Accumulates features and their respective values for calibration.
+  The steps for calibration are typically as follows:
+
+  1. accumulate feature values from batches by calling ``accumulate()``;
+  2. calibrate by calling ``calibrate()``;
+  3. convert to a twml.layers layer by calling ``to_layer()``.
+
+  Note you can only use one calibrator per Trainer.
+  '''
+
+  def __init__(self, calibrator_name=None, **kwargs):
+    '''
+    Arguments:
+      calibrator_name:
+        Name of the calibrator. When set to None, it defaults to the snake_case class name.
+        Please be reminded that if the model contains many calibrators
+        of the same type, the calibrator_name should be changed to avoid confusion.
+    '''
+    self._calibrated = False
+    if calibrator_name is None:
+      calibrator_name = twml.util.to_snake_case(self.__class__.__name__)
+    self._calibrator_name = calibrator_name
+    self._kwargs = kwargs
+
+  @property
+  def is_calibrated(self):
+    return self._calibrated
+
+  @property
+  def name(self):
+    return self._calibrator_name
+
+  def accumulate(self, *args, **kwargs):
+    '''Accumulates features and their respective values for calibration.'''
+    raise NotImplementedError
+
+  def calibrate(self):
+    '''Calibrates after the accumulation has ended.'''
+    self._calibrated = True
+
+  def to_layer(self, name=None):
+    '''
+    Returns a twml.layers.Layer instance with the result of the calibrator.
+
+    Arguments:
+      name:
+        name-scope of the layer
+    '''
+    raise NotImplementedError
+
+  def get_layer_args(self):
+    '''
+    Returns layer arguments required to implement multi-phase training.
+
+    Returns:
+      dictionary of Layer constructor arguments to initialize the
+      layer Variables. Typically, this should contain enough information
+      to initialize empty layer Variables of the correct size, which will then
+      be filled with the right data using init_map.
+    '''
+    raise NotImplementedError
+
+  def save(self, save_dir, name="default", verbose=False):
+    '''Saves the calibrator into the given save directory.
+
+    Arguments:
+      save_dir:
+        name of the saving directory.
+      name:
+        name for the calibrator. Defaults to "default".
+    '''
+    if not self._calibrated:
+      raise RuntimeError("Expecting prior call to calibrate(). Cannot save() prior to calibrate()")
+
+    # This module allows for the calibrator to be saved as part of
+    # Tensorflow Hub (this will allow it to be used in further steps)
+    def calibrator_module():
+      # Note that this is usually expecting a sparse_placeholder
+      inputs = tf.sparse_placeholder(tf.float32)
+      calibrator_layer = self.to_layer()
+      output = calibrator_layer(inputs)
+      # creates the signature to the calibrator module
+      hub.add_signature(inputs=inputs, outputs=output, name=name)
+
+    # exports the module to the save_dir
+    spec = hub.create_module_spec(calibrator_module)
+    with tf.Graph().as_default():
+      module = hub.Module(spec)
+      with tf.Session() as session:
+        module.export(save_dir, session)
+
+  def write_summary(self, writer, sess=None):
+    """
+    This method is called by save() to write tensorboard summaries to disk.
+    See MDLCalibrator.write_summary for an example.
+    By default, the method does nothing. It can be overloaded by child-classes.
+
+    Arguments:
+      writer:
+        `tf.summary.FileWriter
+        <https://www.tensorflow.org/api_docs/python/tf/summary/FileWriter>`_
+        instance.
+        The ``writer`` is used to add summaries to event files for inclusion in tensorboard.
+      sess (optional):
+        `tf.Session <https://www.tensorflow.org/api_docs/python/tf/Session>`_
+        instance. The ``sess`` is used to produce summaries for the writer.
+    """
diff --git a/twml/twml/contrib/calibrators/common_calibrators.py b/twml/twml/contrib/calibrators/common_calibrators.py
new file mode 100644
index 000000000..5301901e4
--- /dev/null
+++ b/twml/twml/contrib/calibrators/common_calibrators.py
@@ -0,0 +1,707 @@
+# pylint: disable=invalid-name, no-member, unused-argument
+"""
+This module contains common calibrate and export functions for calibrators.
+"""
+
+# These 3 TODOs are encapsulated by CX-11446
+# TODO: many of these functions hardcode datarecords yet don't allow passing a parse_fn.
+# TODO: provide more generic (non DataRecord specific) functions
+# TODO: many of these functions aren't common at all.
+# For example, Discretizer functions should be moved to PercentileDiscretizer.
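Before the helpers below, the Calibrator lifecycle from calibrator.py, end to end, as a minimal sketch ('batches' stands in for an actual data source; the bin count mirrors the --calibrator_num_bins default):

    calibrator = IsotonicCalibrator(22)
    for preds, labels, weights in batches:
      calibrator.accumulate(preds, labels, weights)
    calibrator.calibrate()
    layer = calibrator.to_layer()  # twml.layers.Layer applying the learned mapping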
+
+import copy
+import os
+import time
+
+from absl import logging
+import tensorflow.compat.v1 as tf
+import tensorflow_hub as hub
+import twml
+from twml.argument_parser import SortingHelpFormatter
+from twml.input_fns import data_record_input_fn
+from twml.util import list_files_by_datetime, sanitize_hdfs_path
+from twml.contrib.calibrators.isotonic import IsotonicCalibrator
+
+
+def calibrator_arguments(parser):
+  """
+  Adds calibrator parameters to a DataRecordTrainer parser.
+
+  Arguments:
+    parser:
+      Parser with the options to the model
+  """
+  parser.add_argument("--calibrator.save_dir", type=str,
+                      dest="calibrator_save_dir",
+                      help="Path to save or load the calibrator calibration")
+  parser.add_argument("--calibrator_batch_size", type=int, default=128,
+                      dest="calibrator_batch_size",
+                      help="calibrator batch size")
+  parser.add_argument("--calibrator_parts_downsampling_rate", type=float, default=1,
+                      dest="calibrator_parts_downsampling_rate",
+                      help="Parts downsampling rate")
+  parser.add_argument("--calibrator_max_steps", type=int, default=None,
+                      dest="calibrator_max_steps",
+                      help="Max steps taken by the calibrator to accumulate samples")
+  parser.add_argument("--calibrator_num_bins", type=int, default=22,
+                      dest="calibrator_num_bins",
+                      help="Number of bins of the calibrator")
+  parser.add_argument("--isotonic_calibrator", dest='isotonic_calibrator', action='store_true',
+                      help="Isotonic Calibrator present")
+  parser.add_argument("--calibrator_keep_rate", type=float, default=1.0,
+                      dest="calibrator_keep_rate",
+                      help="Keep rate")
+  return parser
+
+
+def _generate_files_by_datetime(params):
+
+  files = list_files_by_datetime(
+    base_path=sanitize_hdfs_path(params.train_data_dir),
+    start_datetime=params.train_start_datetime,
+    end_datetime=params.train_end_datetime,
+    datetime_prefix_format=params.datetime_format,
+    extension="lzo",
+    parallelism=1,
+    hour_resolution=params.hour_resolution,
+    sort=True)
+
+  return files
+
+
+def get_calibrate_input_fn(parse_fn, params):
+  """
+  Default input function used for the calibrator.
+  Arguments:
+    parse_fn:
+      parse function used to decode the records
+    params:
+      parameters
+  Returns:
+    input_fn
+  """
+
+  return lambda: data_record_input_fn(
+    files=_generate_files_by_datetime(params),
+    batch_size=params.calibrator_batch_size,
+    parse_fn=parse_fn,
+    num_threads=1,
+    repeat=False,
+    keep_rate=params.calibrator_keep_rate,
+    parts_downsampling_rate=params.calibrator_parts_downsampling_rate,
+    shards=None,
+    shard_index=None,
+    shuffle=True,
+    shuffle_files=True,
+    interleave=True)
+
+
+def get_discretize_input_fn(parse_fn, params):
+  """
+  Default input function used for the discretizer.
+  Arguments:
+    parse_fn:
+      parse function used to decode the records
+    params:
+      parameters
+  Returns:
+    input_fn
+  """
+
+  return lambda: data_record_input_fn(
+    files=_generate_files_by_datetime(params),
+    batch_size=params.discretizer_batch_size,
+    parse_fn=parse_fn,
+    num_threads=1,
+    repeat=False,
+    keep_rate=params.discretizer_keep_rate,
+    parts_downsampling_rate=params.discretizer_parts_downsampling_rate,
+    shards=None,
+    shard_index=None,
+    shuffle=True,
+    shuffle_files=True,
+    interleave=True)
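These input functions plug into calibrate()/discretize() below roughly as follows (parse_fn typically comes from the model's FeatureConfig and is assumed here):

    input_fn = get_calibrate_input_fn(parse_fn, params)
    features, labels = input_fn()
    weights = features.pop('weights')  # calibrate() requires parse_fn to emit 'weights'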
+
+
+def discretizer_arguments(parser=None):
+  """
+  Adds discretizer parameters to a DataRecordTrainer parser.
+  If parser is None, a default parser is created instead.
+
+  Arguments:
+    parser:
+      Parser with the options to the model. Defaults to None.
+  """
+
+  if parser is None:
+    parser = twml.DefaultSubcommandArgParse(formatter_class=SortingHelpFormatter)
+    parser.add_argument(
+      "--overwrite_save_dir", dest="overwrite_save_dir", action="store_true",
+      help="Delete the contents of the current save_dir if it exists")
+    parser.add_argument(
+      "--train.data_dir", "--train_data_dir", type=str, default=None,
+      dest="train_data_dir",
+      help="Path to the training data directory. "
+      "Supports local and HDFS (hdfs://default/) paths.")
+    parser.add_argument(
+      "--train.start_date", "--train_start_datetime",
+      type=str, default=None,
+      dest="train_start_datetime",
+      help="Starting date for training inside the train data dir. "
+      "The start datetime is inclusive. "
+      "e.g. 2019/01/15")
+    parser.add_argument(
+      "--train.end_date", "--train_end_datetime", type=str, default=None,
+      dest="train_end_datetime",
+      help="Ending date for training inside the train data dir. "
+      "The end datetime is inclusive. "
+      "e.g. 2019/01/15")
+    parser.add_argument(
+      "--datetime_format", type=str, default="%Y/%m/%d",
+      help="Date format for training and evaluation datasets. "
+      "Has to be a format that is understood by python datetime. "
+      "e.g. %Y/%m/%d for 2019/01/15. "
+      "Used only if {train/eval}.{start/end}_date are provided.")
+    parser.add_argument(
+      "--hour_resolution", type=int, default=None,
+      help="Specify the hourly resolution of the stored data.")
+    parser.add_argument(
+      "--tensorboard_port", type=int, default=None,
+      help="Port for tensorboard to run on.")
+    parser.add_argument(
+      "--stats_port", type=int, default=None,
+      help="Port for the stats server to run on.")
+    parser.add_argument(
+      "--health_port", type=int, default=None,
+      help="Port to listen on for health-related endpoints (e.g. graceful shutdown). "
+      "Not user-facing as it is set automatically by the twml_cli."
+    )
+    parser.add_argument(
+      "--data_spec", type=str, default=None,
+      help="Path to the data specification JSON file. This file is used to decode DataRecords")
+  parser.add_argument("--discretizer.save_dir", type=str,
+                      dest="discretizer_save_dir",
+                      help="Path to save or load the discretizer calibration")
+  parser.add_argument("--discretizer_batch_size", type=int, default=128,
+                      dest="discretizer_batch_size",
+                      help="Discretizer batch size")
+  parser.add_argument("--discretizer_keep_rate", type=float, default=0.0008,
+                      dest="discretizer_keep_rate",
+                      help="Keep rate")
+  parser.add_argument("--discretizer_parts_downsampling_rate", type=float, default=0.2,
+                      dest="discretizer_parts_downsampling_rate",
+                      help="Parts downsampling rate")
+  parser.add_argument("--discretizer_max_steps", type=int, default=None,
+                      dest="discretizer_max_steps",
+                      help="Max steps taken by the discretizer to accumulate samples")
+  return parser
Returns the calibrator + """ + + if trainer._estimator.config.is_chief: + + # overwrite the current save_dir + if params.overwrite_save_dir and tf.io.gfile.exists(params.calibrator_save_dir): + logging.info("Trainer overwriting existing save directory: %s (params.overwrite_save_dir)" + % params.calibrator_save_dir) + tf.io.gfile.rmtree(params.calibrator_save_dir) + + calibrator = IsotonicCalibrator(params.calibrator_num_bins) + + # chief trains discretizer + logging.info("Chief training calibrator") + + # Accumulate the features for each calibrator + features, labels = input_fn() + if 'weights' not in features: + raise ValueError("Weights need to be returned as part of the parse_fn") + weights = features.pop('weights') + + preds = build_graph(features=features, label=None, mode='infer', params=params, config=None) + init = tf.global_variables_initializer() + table_init = tf.tables_initializer() + with tf.Session() as sess: + sess.run(init) + sess.run(table_init) + count = 0 + max_steps = params.calibrator_max_steps or -1 + while max_steps <= 0 or count <= max_steps: + try: + weights_vals, labels_vals, preds_vals = sess.run([weights, labels, preds['output']]) + calibrator.accumulate(preds_vals, labels_vals, weights_vals.flatten()) + except tf.errors.OutOfRangeError: + break + count += 1 + + calibrator.calibrate() + calibrator.save(params.calibrator_save_dir) + trainer.estimator._params.isotonic_calibrator = True + + if debug: + return calibrator + + else: + calibrator_save_dir = twml.util.sanitize_hdfs_path(params.calibrator_save_dir) + # workers wait for calibration to be ready + while not tf.io.gfile.exists(calibrator_save_dir + os.path.sep + "tfhub_module.pb"): + logging.info("Worker waiting for calibration at %s" % calibrator_save_dir) + time.sleep(60) + + +def discretize(params, feature_config, input_fn, debug=False): + """ + Discretizes continuous features + Arguments: + params: + Parameters + input_fn: + Input Function specified by the user + debug: + Defaults to False. 
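+
+    Example (sketch; assumes ``feature_config`` is a callable whose config map
+    contains the ``discretize_config`` dict this function reads)::
+
+      input_fn = get_discretize_input_fn(parse_fn, params)
+      discretize(params, feature_config, input_fn)
+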
Returns the calibrator + """ + + if (os.environ.get("TWML_HOGWILD_TASK_TYPE") == "chief" or "num_workers" not in params or + params.num_workers is None): + + # overwrite the current save_dir + if params.overwrite_save_dir and tf.io.gfile.exists(params.discretizer_save_dir): + logging.info("Trainer overwriting existing save directory: %s (params.overwrite_save_dir)" + % params.discretizer_save_dir) + tf.io.gfile.rmtree(params.discretizer_save_dir) + + config_map = feature_config() + discretize_dict = config_map['discretize_config'] + + # chief trains discretizer + logging.info("Chief training discretizer") + + batch = input_fn() + # Accumulate the features for each calibrator + with tf.Session() as sess: + count = 0 + max_steps = params.discretizer_max_steps or -1 + while max_steps <= 0 or count <= max_steps: + try: + inputs = sess.run(batch) + for name, clbrt in discretize_dict.items(): + clbrt.accumulate_features(inputs[0], name) + except tf.errors.OutOfRangeError: + break + count += 1 + + # This module allows for the calibrator to save be saved as part of + # Tensorflow Hub (this will allow it to be used in further steps) + def calibrator_module(): + # Note that this is usually expecting a sparse_placeholder + for name, clbrt in discretize_dict.items(): + clbrt.calibrate() + clbrt.add_hub_signatures(name) + + # exports the module to the save_dir + spec = hub.create_module_spec(calibrator_module) + with tf.Graph().as_default(): + module = hub.Module(spec) + with tf.Session() as session: + module.export(params.discretizer_save_dir, session) + + for name, clbrt in discretize_dict.items(): + clbrt.write_summary_json(params.discretizer_save_dir, name) + + if debug: + return discretize_dict + + else: + # wait for the file to be removed (if necessary) + # should be removed after an actual fix applied + time.sleep(60) + discretizer_save_dir = twml.util.sanitize_hdfs_path(params.discretizer_save_dir) + # workers wait for calibration to be ready + while not tf.io.gfile.exists(discretizer_save_dir + os.path.sep + "tfhub_module.pb"): + logging.info("Worker waiting for calibration at %s" % discretizer_save_dir) + time.sleep(60) + + +def add_discretizer_arguments(parser): + """ + Add discretizer-specific command-line arguments to a Trainer parser. + + Arguments: + parser: argparse.ArgumentParser instance obtained from Trainer.get_trainer_parser + + Returns: + argparse.ArgumentParser instance with discretizer-specific arguments added + """ + + parser.add_argument("--discretizer.save_dir", type=str, + dest="discretizer_save_dir", + help="Path to save or load discretizer calibration") + parser.add_argument("--discretizer.batch_size", type=int, default=128, + dest="discretizer_batch_size", + help="Discretizer batch size") + parser.add_argument("--discretizer.keep_rate", type=float, default=0.0008, + dest="discretizer_keep_rate", + help="Keep rate") + parser.add_argument("--discretizer.parts_downsampling_rate", type=float, default=0.2, + dest="discretizer_parts_downsampling_rate", + help="Parts downsampling rate") + parser.add_argument("--discretizer.num_bins", type=int, default=20, + dest="discretizer_num_bins", + help="Number of bins per feature") + parser.add_argument("--discretizer.output_size_bits", type=int, default=22, + dest="discretizer_output_size_bits", + help="Number of bits allocated to the output size") + return parser + + +def add_isotonic_calibrator_arguments(parser): + """ + Add discretizer-specific command-line arguments to a Trainer parser. 
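+  A typical invocation (sketch, using the parser described below)::
+
+    parser = Trainer.get_trainer_parser()
+    parser = add_isotonic_calibrator_arguments(parser)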
+ + Arguments: + parser: argparse.ArgumentParser instance obtained from Trainer.get_trainer_parser + + Returns: + argparse.ArgumentParser instance with discretizer-specific arguments added + """ + parser.add_argument("--calibrator.num_bins", type=int, + default=25000, dest="calibrator_num_bins", + help="number of bins for isotonic calibration") + parser.add_argument("--calibrator.parts_downsampling_rate", type=float, default=0.1, + dest="calibrator_parts_downsampling_rate", help="Parts downsampling rate") + parser.add_argument("--calibrator.save_dir", type=str, + dest="calibrator_save_dir", help="Path to save or load calibrator output") + parser.add_argument("--calibrator.load_tensorflow_module", type=str, default=None, + dest="calibrator_load_tensorflow_module", + help="Location from where to load a pretrained graph from. \ + Typically, this is where the MLP graph is saved") + parser.add_argument("--calibrator.export_mlp_module_name", type=str, default='tf_hub_mlp', + help="Name for loaded hub signature", + dest="export_mlp_module_name") + parser.add_argument("--calibrator.export_isotonic_module_name", + type=str, default="tf_hub_isotonic", + dest="calibrator_export_module_name", + help="export module name") + parser.add_argument("--calibrator.final_evaluation_steps", type=int, + dest="calibrator_final_evaluation_steps", default=None, + help="number of steps for final evaluation") + parser.add_argument("--calibrator.train_steps", type=int, default=-1, + dest="calibrator_train_steps", + help="number of steps for calibration") + parser.add_argument("--calibrator.batch_size", type=int, default=1024, + dest="calibrator_batch_size", + help="Calibrator batch size") + parser.add_argument("--calibrator.is_calibrating", action='store_true', + dest="is_calibrating", + help="Dummy argument to allow running in chief worker") + return parser + + +def calibrate_calibrator_and_export(name, calibrator, build_graph_fn, params, feature_config, + run_eval=True, input_fn=None, metric_fn=None, + export_task_type_overrider=None): + """ + Pre-set `isotonic calibrator` calibrator. + Args: + name: + scope name used for the calibrator + calibrator: + calibrator that will be calibrated and exported. + build_graph_fn: + build graph function for the calibrator + params: + params passed to the calibrator + feature_config: + feature config which will be passed to the trainer + export_task_type_overrider: + the task type for exporting the calibrator + if specified, this will override the default export task type in trainer.hub_export(..) + """ + + # create calibrator params + params_c = copy.deepcopy(params) + params_c.data_threads = 1 + params_c.num_workers = 1 + params_c.continue_from_checkpoint = True + params_c.overwrite_save_dir = False + params_c.stats_port = None + + # Automatically load from the saved Tensorflow Hub module if not specified. + if params_c.calibrator_load_tensorflow_module is None: + path_saved_tensorflow_model = os.path.join(params.save_dir, params.export_mlp_module_name) + params_c.calibrator_load_tensorflow_module = path_saved_tensorflow_model + + if "calibrator_parts_downsampling_rate" in params_c: + params_c.train_parts_downsampling_rate = params_c.calibrator_parts_downsampling_rate + if "calibrator_save_dir" in params_c: + params_c.save_dir = params_c.calibrator_save_dir + if "calibrator_batch_size" in params_c: + params_c.train_batch_size = params_c.calibrator_batch_size + params_c.eval_batch_size = params_c.calibrator_batch_size + # TODO: Deprecate this option. It is not actually used. 
Calibrator + # simply iterates until the end of input_fn. + if "calibrator_train_steps" in params_c: + params_c.train_steps = params_c.calibrator_train_steps + + if metric_fn is None: + metric_fn = twml.metrics.get_multi_binary_class_metric_fn(None) + + # Common Trainer which will also be used by all workers + trainer = twml.trainers.DataRecordTrainer( + name=name, + params=params_c, + feature_config=feature_config, + build_graph_fn=build_graph_fn, + save_dir=params_c.save_dir, + metric_fn=metric_fn + ) + + if trainer._estimator.config.is_chief: + + # Chief trains calibrator + logging.info("Chief training calibrator") + + # Disregard hogwild config + os_twml_hogwild_ports = os.environ.get("TWML_HOGWILD_PORTS") + os.environ["TWML_HOGWILD_PORTS"] = "" + + hooks = None + if params_c.calibrator_train_steps > 0: + hooks = [twml.hooks.StepProgressHook(params_c.calibrator_train_steps)] + + def parse_fn(input_x): + fc_parse_fn = feature_config.get_parse_fn() + features, labels = fc_parse_fn(input_x) + features['labels'] = labels + return features, labels + + if input_fn is None: + input_fn = trainer.get_train_input_fn(parse_fn=parse_fn, repeat=False) + + # Calibrate stage + trainer.estimator._params.mode = 'calibrate' + trainer.calibrate(calibrator=calibrator, + input_fn=input_fn, + steps=params_c.calibrator_train_steps, + hooks=hooks) + + # Save Checkpoint + # We need to train for 1 step, to save the graph to checkpoint. + # This is done just by the chief. + # We need to set the mode to evaluate to save the graph that will be consumed + # In the final evaluation + trainer.estimator._params.mode = 'evaluate' + trainer.train(input_fn=input_fn, steps=1) + + # Restore hogwild setup + if os_twml_hogwild_ports is not None: + os.environ["TWML_HOGWILD_PORTS"] = os_twml_hogwild_ports + else: + # Workers wait for calibration to be ready + final_calibrator_path = os.path.join(params_c.calibrator_save_dir, + params_c.calibrator_export_module_name) + + final_calibrator_path = twml.util.sanitize_hdfs_path(final_calibrator_path) + + while not tf.io.gfile.exists(final_calibrator_path + os.path.sep + "tfhub_module.pb"): + logging.info("Worker waiting for calibration at %s" % final_calibrator_path) + time.sleep(60) + + # Evaluate stage + if run_eval: + trainer.estimator._params.mode = 'evaluate' + # This will allow the Evaluate method to be run in Hogwild + # trainer.estimator._params.continue_from_checkpoint = True + trainer.evaluate(name='test', input_fn=input_fn, steps=params_c.calibrator_final_evaluation_steps) + + trainer.hub_export(name=params_c.calibrator_export_module_name, + export_task_type_overrider=export_task_type_overrider, + serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn()) + + return trainer + + +def calibrate_discretizer_and_export(name, calibrator, build_graph_fn, params, feature_config): + """ + Pre-set percentile discretizer calibrator. + Args: + name: + scope name used for the calibrator + calibrator: + calibrator that will be calibrated and exported. + build_graph_fn: + build graph function for the calibrator + params: + params passed to the calibrator + feature_config: + feature config or input_fn which will be passed to the trainer. 
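+
+  Example (sketch; pairs this helper with a ``PercentileDiscretizerCalibrator``
+  and the ``build_percentile_discretizer_graph`` defined below)::
+
+    calibrator = PercentileDiscretizerCalibrator(n_bin=20, out_bits=22)
+    calibrate_discretizer_and_export(
+      "percentile_discretizer", calibrator,
+      build_percentile_discretizer_graph, params, feature_config)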
+ """ + + if (os.environ.get("TWML_HOGWILD_TASK_TYPE") == "chief" or "num_workers" not in params or + params.num_workers is None): + + # chief trains discretizer + logging.info("Chief training discretizer") + + # disregard hogwild config + os_twml_hogwild_ports = os.environ.get("TWML_HOGWILD_PORTS") + os.environ["TWML_HOGWILD_PORTS"] = "" + + # create discretizer params + params_c = copy.deepcopy(params) + params_c.data_threads = 1 + params_c.train_steps = -1 + params_c.train_max_steps = None + params_c.eval_steps = -1 + params_c.num_workers = 1 + params_c.tensorboard_port = None + params_c.stats_port = None + + if "discretizer_batch_size" in params_c: + params_c.train_batch_size = params_c.discretizer_batch_size + params_c.eval_batch_size = params_c.discretizer_batch_size + if "discretizer_keep_rate" in params_c: + params_c.train_keep_rate = params_c.discretizer_keep_rate + if "discretizer_parts_downsampling_rate" in params_c: + params_c.train_parts_downsampling_rate = params_c.discretizer_parts_downsampling_rate + if "discretizer_save_dir" in params_c: + params_c.save_dir = params_c.discretizer_save_dir + + # train discretizer + trainer = twml.trainers.DataRecordTrainer( + name=name, + params=params_c, + build_graph_fn=build_graph_fn, + save_dir=params_c.save_dir, + ) + + if isinstance(feature_config, twml.feature_config.FeatureConfig): + parse_fn = twml.parsers.get_continuous_parse_fn(feature_config) + input_fn = trainer.get_train_input_fn(parse_fn=parse_fn, repeat=False) + elif callable(feature_config): + input_fn = feature_config + else: + got_type = type(feature_config).__name__ + raise ValueError( + "Expecting feature_config to be FeatureConfig or function got %s" % got_type) + + hooks = None + if params_c.train_steps > 0: + hooks = [twml.hooks.StepProgressHook(params_c.train_steps)] + + trainer.calibrate(calibrator=calibrator, input_fn=input_fn, + steps=params_c.train_steps, hooks=hooks) + # restore hogwild setup + if os_twml_hogwild_ports is not None: + os.environ["TWML_HOGWILD_PORTS"] = os_twml_hogwild_ports + else: + discretizer_save_dir = twml.util.sanitize_hdfs_path(params.discretizer_save_dir) + # workers wait for calibration to be ready + while not tf.io.gfile.exists(discretizer_save_dir + os.path.sep + "tfhub_module.pb"): + logging.info("Worker waiting for calibration at %s" % discretizer_save_dir) + time.sleep(60) + + +def build_percentile_discretizer_graph(features, label, mode, params, config=None): + """ + Pre-set Percentile Discretizer Build Graph + Follows the same signature as build_graph + """ + sparse_tf = twml.util.convert_to_sparse(features, params.input_size_bits) + weights = tf.reshape(features['weights'], tf.reshape(features['batch_size'], [1])) + if isinstance(sparse_tf, tf.SparseTensor): + indices = sparse_tf.indices[:, 1] + ids = sparse_tf.indices[:, 0] + elif isinstance(sparse_tf, twml.SparseTensor): + indices = sparse_tf.indices + ids = sparse_tf.ids + + # Return weights, feature_ids, feature_values + weights = tf.gather(params=weights, indices=ids) + feature_ids = indices + feature_values = sparse_tf.values + # Update train_op and assign dummy_loss + train_op = tf.assign_add(tf.train.get_global_step(), 1) + loss = tf.constant(1) + if mode == 'train': + return {'train_op': train_op, 'loss': loss} + return {'feature_ids': feature_ids, 'feature_values': feature_values, 'weights': weights} + + +def isotonic_module(mode, params): + """ + Common Isotonic Calibrator module for Hub Export + """ + inputs = tf.sparse_placeholder(tf.float32, name="sparse_input") + 
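+  # The sparse input first goes through the previously exported MLP hub module
+  # to produce logits; the logits then go through the saved isotonic calibrator
+  # module, and the signature added below exposes both stages as a single graph.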
mlp = hub.Module(params.calibrator_load_tensorflow_module) + logits = mlp(inputs, signature=params.export_mlp_module_name) + isotonic_calibrator = hub.Module(params.save_dir) + output = isotonic_calibrator(logits, signature="isotonic_calibrator") + hub.add_signature(inputs={"sparse_input": inputs}, + outputs={"default": output}, + name=params.calibrator_export_module_name) + + +def build_isotonic_graph_from_inputs(inputs, features, label, mode, params, config=None, isotonic_fn=None): + """ + Helper function to build_isotonic_graph + Pre-set Isotonic Calibrator Build Graph + Follows the same signature as build_graph + """ + if params.mode == 'calibrate': + mlp = hub.Module(params.calibrator_load_tensorflow_module) + logits = mlp(inputs, signature=params.export_mlp_module_name) + weights = tf.reshape(features['weights'], tf.reshape(features['batch_size'], [1])) + # Update train_op and assign dummy_loss + train_op = tf.assign_add(tf.train.get_global_step(), 1) + loss = tf.constant(1) + if mode == 'train': + return {'train_op': train_op, 'loss': loss} + return {'predictions': logits, 'targets': features['labels'], 'weights': weights} + else: + if isotonic_fn is None: + isotonic_spec = twml.util.create_module_spec(mlp_fn=isotonic_module, mode=mode, params=params) + else: + isotonic_spec = twml.util.create_module_spec(mlp_fn=isotonic_fn, mode=mode, params=params) + output_hub = hub.Module(isotonic_spec, + name=params.calibrator_export_module_name) + hub.register_module_for_export(output_hub, params.calibrator_export_module_name) + output = output_hub(inputs, signature=params.calibrator_export_module_name) + output = tf.clip_by_value(output, 0, 1) + loss = tf.reduce_sum(tf.stop_gradient(output)) + train_op = tf.assign_add(tf.train.get_global_step(), 1) + return {'train_op': train_op, 'loss': loss, 'output': output} + + +def build_isotonic_graph(features, label, mode, params, config=None, export_discretizer=True): + """ + Pre-set Isotonic Calibrator Build Graph + Follows the same signature as build_graph + This assumes that MLP already contains all modules (include percentile + discretizer); if export_discretizer is set + then it does not export the MDL phase. + """ + sparse_tf = twml.util.convert_to_sparse(features, params.input_size_bits) + if export_discretizer: + return build_isotonic_graph_from_inputs(sparse_tf, features, label, mode, params, config) + discretizer = hub.Module(params.discretizer_path) + + if params.discretizer_signature is None: + discretizer_signature = "percentile_discretizer_calibrator" + else: + discretizer_signature = params.discretizer_signature + input_sparse = discretizer(sparse_tf, signature=discretizer_signature) + return build_isotonic_graph_from_inputs(input_sparse, features, label, mode, params, config) diff --git a/twml/twml/contrib/calibrators/hashed_percentile_discretizer.py b/twml/twml/contrib/calibrators/hashed_percentile_discretizer.py new file mode 100644 index 000000000..e14f62303 --- /dev/null +++ b/twml/twml/contrib/calibrators/hashed_percentile_discretizer.py @@ -0,0 +1,22 @@ +# pylint: disable=arguments-differ,no-member,too-many-statements +''' Contains HashedPercentileDiscretizerCalibrator used for calibration ''' +from .percentile_discretizer import PercentileDiscretizerCalibrator + +import twml + + +class HashedPercentileDiscretizerCalibrator(PercentileDiscretizerCalibrator): + ''' Accumulates features and their respective values for HashedPercentileDiscretizer calibration. 
+
+  This calibrator performs the same actions as PercentileDiscretizerCalibrator, but its
+  `to_layer` method returns a HashedPercentileDiscretizer instead.
+  '''
+
+  def _create_discretizer_layer(self, n_feature, hash_map_keys, hash_map_values,
+                                feature_offsets, name):
+    return twml.contrib.layers.HashedPercentileDiscretizer(
+      n_feature=n_feature, n_bin=self._n_bin,
+      name=name, out_bits=self._out_bits,
+      hash_keys=hash_map_keys, hash_values=hash_map_values,
+      bin_ids=self._bin_ids.flatten(), bin_values=self._bin_vals.flatten(),
+      feature_offsets=feature_offsets
+    )
diff --git a/twml/twml/contrib/calibrators/hashing_discretizer.py b/twml/twml/contrib/calibrators/hashing_discretizer.py
new file mode 100644
index 000000000..965ced934
--- /dev/null
+++ b/twml/twml/contrib/calibrators/hashing_discretizer.py
@@ -0,0 +1,35 @@
+# pylint: disable=arguments-differ,no-member,too-many-statements
+''' Contains HashingDiscretizerCalibrator used for calibration '''
+from .percentile_discretizer import PercentileDiscretizerCalibrator
+
+import numpy as np
+import twml
+
+
+class HashingDiscretizerCalibrator(PercentileDiscretizerCalibrator):
+  ''' Accumulates features and their respective values for HashingDiscretizer calibration.
+  This calibrator performs the same actions as PercentileDiscretizerCalibrator, but its
+  `to_layer` method returns a HashingDiscretizer instead.
+  '''
+
+  def _create_discretizer_layer(self, n_feature, hash_map_keys, hash_map_values,
+                                feature_offsets, name):
+    # Sort hash_map_keys according to hash_map_values, in case they are not in
+    # the order they were put into the dict. hash_map_values is already
+    # 0 through len(hash_map_values) - 1.
+    hash_map_keys = hash_map_keys.flatten()
+    # hash_map_values arrives as float32 (see PercentileDiscretizerCalibrator.to_layer);
+    # cast to int because it is used for indexing below.
+    hash_map_values = hash_map_values.flatten().astype(np.int32)
+    feature_ids = np.zeros((len(hash_map_keys),), dtype=np.int64)
+    for idx in range(len(hash_map_keys)):
+      feature_ids[hash_map_values[idx]] = hash_map_keys[idx]
+
+    return twml.contrib.layers.HashingDiscretizer(
+      feature_ids=feature_ids,
+      bin_vals=self._bin_vals.flatten(),
+      n_bin=self._n_bin + 1,  # (self._n_bin + 1) bin_vals for each feature_id
+      out_bits=self._out_bits,
+      cost_per_unit=500,
+      name=name
+    )
diff --git a/twml/twml/contrib/calibrators/isotonic.py b/twml/twml/contrib/calibrators/isotonic.py
new file mode 100644
index 000000000..d03a75ff8
--- /dev/null
+++ b/twml/twml/contrib/calibrators/isotonic.py
@@ -0,0 +1,317 @@
+# pylint: disable=arguments-differ, unused-argument
+''' Contains Isotonic Calibration '''
+
+from .calibrator import CalibrationFeature, Calibrator
+
+from absl import logging
+import numpy as np
+from sklearn.isotonic import isotonic_regression
+import tensorflow.compat.v1 as tf
+import tensorflow_hub as hub
+import twml
+import twml.layers
+
+
+DEFAULT_SAMPLE_WEIGHT = 1
+
+
+def sort_values(inputs, target, weight, ascending=True):
+  '''
+  Sorts arrays based on the first array.
+
+  Arguments:
+    inputs:
+      1D array that dictates the order in which the remaining two arrays are sorted
+    target:
+      1D array
+    weight:
+      1D array
+    ascending:
+      Boolean. If set to True (the default), sorts values in ascending order.
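+
+  Example (sketch)::
+
+    x = np.array([3.0, 1.0, 2.0])
+    srtd_x, srtd_t, srtd_w = sort_values(x, x * 10, np.ones(3))
+    # srtd_x is now array([1., 2., 3.])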
+ + Returns: + sorted inputs: + 1D array sorted by the order of `ascending` + sorted targets: + 1D array + sorted weight: + 1D array + ''' + # assert that the length of inputs and target are the same + if len(inputs) != len(target): + raise ValueError('Expecting inputs and target sizes to match') + # assert that the length of inputs and weight are the same + if len(inputs) != len(weight): + raise ValueError('Expecting inputs and weight sizes to match') + inds = inputs.argsort() + if not ascending: + inds = inds[::-1] + return inputs[inds], target[inds], weight[inds] + + +class IsotonicFeature(CalibrationFeature): + ''' + IsotonicFeature adds values, weights and targets to each feature and then runs + isotonic regression by calling `sklearn.isotonic.isotonic_regression + `_ + ''' + + def _get_bin_boundaries(self, n_samples, bins, similar_bins): + """ + Calculates the sample indices that define bin boundaries + + Arguments: + n_samples: + (int) number of samples + bins: + (int) number of bins. Needs to be smaller or equal than n_samples. + similar_bins: + (bool) If True, samples will be distributed in bins of equal size (up to one sample). + If False bins will be filled with step = N_samples//bins, and last bin will contain all remaining samples. + Note that equal_bins=False can create a last bins with a very large number of samples. + + Returns: + (list[int]) List of sample indices defining bin boundaries + """ + + if bins > n_samples: + raise ValueError( + "The number of bins needs to be less than or equal to the number of samples. " + "Currently bins={0} and n_samples={1}.".format(bins, n_samples) + ) + + step = n_samples // bins + + if similar_bins: + # dtype=int will floor the linspace + bin_boundaries = np.linspace(0, n_samples - step, num=bins, dtype=int) + else: + bin_boundaries = range(0, step * bins, step) + + bin_boundaries = np.append(bin_boundaries, n_samples) + + return bin_boundaries + + def calibrate(self, bins, similar_bins=False, debug=False): + '''Calibrates the IsotonicFeature into calibrated weights and bias. + + 1. Sorts the values of the feature class, based on the order of values + 2. Performs isotonic regression using sklearn.isotonic.isotonic_regression + 3. Performs the binning of the samples, in order to obtain the final weight and bias + which will be used for inference + + Note that this method can only be called once. + + Arguments: + bins: + number of bins. + similar_bins: + If True, samples will be distributed in bins of equal size (up to one sample). + If False bins will be filled with step = N_samples//bins, and last bin will contain all remaining samples. + Note that equal_bins=False can create a last bins with a very large number of samples. + debug: + Defaults to False. If debug is set to true, output other parameters useful for debugging. 
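+
+  For instance, with 10 samples and ``bins=3`` the step is 3, so the default
+  bin boundaries are ``[0, 3, 6, 10]`` and the last bin absorbs the remaining
+  samples; with ``similar_bins=True`` the boundaries are ``[0, 3, 7, 10]``,
+  spreading the samples evenly up to one sample.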
+ + Returns: + [calibrated weight, calibrated bias] + ''' + if self._calibrated: + raise RuntimeError("Can only calibrate once") + # parse through the dict to obtain the targets, weights and values + self._concat_arrays() + feature_targets = self._features_dict['targets'] + feature_values = self._features_dict['values'] + feature_weights = self._features_dict['weights'] + srtd_feature_values, srtd_feature_targets, srtd_feature_weights = sort_values( + inputs=feature_values, + target=feature_targets, + weight=feature_weights + ) + calibrated_feature_values = isotonic_regression( + srtd_feature_targets, sample_weight=srtd_feature_weights) + # create the final outputs for the prediction of each class + bpreds = [] + btargets = [] + bweights = [] + rpreds = [] + + # Create bin boundaries + bin_boundaries = self._get_bin_boundaries( + len(calibrated_feature_values), bins, similar_bins=similar_bins) + + for sidx, eidx in zip(bin_boundaries, bin_boundaries[1:]): + # separate each one of the arrays based on their respective bins + lpreds = srtd_feature_values[int(sidx):int(eidx)] + lrpreds = calibrated_feature_values[int(sidx):int(eidx)] + ltargets = srtd_feature_targets[int(sidx):int(eidx)] + lweights = srtd_feature_weights[int(sidx):int(eidx)] + + # calculate the outputs (including the bpreds and rpreds) + bpreds.append(np.sum(lpreds * lweights) / (np.squeeze(np.sum(lweights)))) + rpreds.append(np.sum(lrpreds * lweights) / (np.squeeze(np.sum(lweights)))) + btargets.append(np.sum(ltargets * lweights) / (np.squeeze(np.sum(lweights)))) + bweights.append(np.squeeze(np.sum(lweights))) + # transposing the bpreds and rpreds which will be used as input to the inference step + bpreds = np.asarray(bpreds).T + rpreds = np.asarray(rpreds).T + btargets = np.asarray(btargets).T + bweights = np.asarray(bweights).T + # setting _calibrated to be True which is necessary in order to prevent it to re-calibrate + self._calibrated = True + if debug: + return bpreds, rpreds, btargets, bweights + return bpreds, rpreds + + +class IsotonicCalibrator(Calibrator): + ''' Accumulates features and their respective values for isotonic calibration. + Internally, each feature's values is accumulated via its own isotonicFeature object. + The steps for calibration are typically as follows: + + 1. accumulate feature values from batches by calling ``accumulate()``; + 2. calibrate all feature into Isotonic ``bpreds``, ``rpreds`` by calling ``calibrate()``; and + 3. convert to a ``twml.layers.Isotonic`` layer by calling ``to_layer()``. + + ''' + + def __init__(self, n_bin, similar_bins=False, **kwargs): + ''' Constructs an isotonicCalibrator instance. + + Arguments: + n_bin: + the number of bins per feature to use for isotonic. + Note that each feature actually maps to ``n_bin+1`` output IDs. + ''' + super(IsotonicCalibrator, self).__init__(**kwargs) + self._n_bin = n_bin + self._similar_bins = similar_bins + self._ys_input = [] + self._xs_input = [] + self._isotonic_feature_dict = {} + + def accumulate_feature(self, output): + ''' + Wrapper around accumulate for trainer API. + Arguments: + output: output of prediction of build_graph for calibrator + ''' + weights = output['weights'] if 'weights' in output else None + return self.accumulate(output['predictions'], output['targets'], weights) + + def accumulate(self, predictions, targets, weights=None): + ''' + Accumulate a single batch of class predictions, class targets and class weights. + These are accumulated until calibrate() is called. 
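+
+  A minimal sketch of the accumulate/calibrate flow (two-class shapes assumed)::
+
+    calibrator = IsotonicCalibrator(n_bin=2)
+    preds = np.array([[0.1, 0.9], [0.6, 0.4], [0.8, 0.2], [0.3, 0.7]])
+    targets = np.array([[0., 1.], [1., 0.], [1., 0.], [0., 1.]])
+    calibrator.accumulate(preds, targets)
+    calibrator.calibrate()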
+ + Arguments: + predictions: + float matrix of class values. Each dimension corresponds to a different class. + Shape is ``[n, d]``, where d is the number of classes. + targets: + float matrix of class targets. Each dimension corresponds to a different class. + Shape ``[n, d]``, where d is the number of classes. + weights: + Defaults to weights of 1. + 1D array containing the weights of each prediction. + ''' + if predictions.shape != targets.shape: + raise ValueError( + 'Expecting predictions.shape == targets.shape, got %s and %s instead' % + (str(predictions.shape), str(targets.shape))) + if weights is not None: + if weights.ndim != 1: + raise ValueError('Expecting 1D weight, got %dD instead' % weights.ndim) + elif weights.size != predictions.shape[0]: + raise ValueError( + 'Expecting predictions.shape[0] == weights.size, got %d != %d instead' % + (predictions.shape[0], weights.size)) + # iterate through the rows of predictions and sets one class to each row + if weights is None: + weights = np.full(predictions.shape[0], fill_value=DEFAULT_SAMPLE_WEIGHT) + for class_key in range(predictions.shape[1]): + # gets the predictions and targets for that class + class_predictions = predictions[:, class_key] + class_targets = targets[:, class_key] + if class_key not in self._isotonic_feature_dict: + isotonic_feature = IsotonicFeature(class_key) + self._isotonic_feature_dict[class_key] = isotonic_feature + else: + isotonic_feature = self._isotonic_feature_dict[class_key] + isotonic_feature.add_values({'values': class_predictions, 'weights': weights, + 'targets': class_targets}) + + def calibrate(self, debug=False): + ''' + Calibrates each IsotonicFeature after accumulation is complete. + Results are stored in ``self._ys_input`` and ``self._xs_input`` + + Arguments: + debug: + Defaults to False. If set to true, returns the ``xs_input`` and ``ys_input``. + ''' + super(IsotonicCalibrator, self).calibrate() + bias_temp = [] + weight_temp = [] + logging.info("Beginning isotonic calibration.") + isotonic_features_dict = self._isotonic_feature_dict + for class_id in isotonic_features_dict: + bpreds, rpreds = isotonic_features_dict[class_id].calibrate(bins=self._n_bin, similar_bins=self._similar_bins) + weight_temp.append(bpreds) + bias_temp.append(rpreds) + # save isotonic results onto a matrix + self._xs_input = np.array(weight_temp, dtype=np.float32) + self._ys_input = np.array(bias_temp, dtype=np.float32) + logging.info("Isotonic calibration finished.") + if debug: + return np.array(weight_temp), np.array(bias_temp) + return None + + def save(self, save_dir, name="default", verbose=False): + '''Save the calibrator into the given save_directory. + Arguments: + save_dir: + name of the saving directory. Default (string): "default". + ''' + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate().Cannot save() prior to calibrate()") + + # This module allows for the calibrator to save be saved as part of + # Tensorflow Hub (this will allow it to be used in further steps) + logging.info("You probably do not need to save the isotonic layer. \ + So feel free to set save to False in the Trainer. 
\ + Additionally this only saves the layer not the whole graph.") + + def calibrator_module(): + ''' + Way to save Isotonic layer + ''' + # The input to isotonic is a dense layer + inputs = tf.placeholder(tf.float32) + calibrator_layer = self.to_layer() + output = calibrator_layer(inputs) + # creates the signature to the calibrator module + hub.add_signature(inputs=inputs, outputs=output, name=name) + + # exports the module to the save_dir + spec = hub.create_module_spec(calibrator_module) + with tf.Graph().as_default(): + module = hub.Module(spec) + with tf.Session() as session: + module.export(save_dir, session) + + def to_layer(self): + """ Returns a twml.layers.Isotonic Layer that can be used for feature discretization. + """ + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate()") + + isotonic_layer = twml.layers.Isotonic( + n_unit=self._xs_input.shape[0], n_bin=self._xs_input.shape[1], + xs_input=self._xs_input, ys_input=self._ys_input, + **self._kwargs) + + return isotonic_layer + + def get_layer_args(self, name=None): + """ Returns layer args. See ``Calibrator.get_layer_args`` for more detailed documentation """ + return {'n_unit': self._xs_input.shape[0], 'n_bin': self._xs_input.shape[1]} diff --git a/twml/twml/contrib/calibrators/mdl.py b/twml/twml/contrib/calibrators/mdl.py new file mode 100644 index 000000000..0fe3265a4 --- /dev/null +++ b/twml/twml/contrib/calibrators/mdl.py @@ -0,0 +1,118 @@ +# pylint: disable=arguments-differ,no-member,too-many-statements +''' Contains MDLFeature and MDLCalibrator used for MDL calibration ''' + + +import os + +from .percentile_discretizer import PercentileDiscretizerCalibrator, PercentileDiscretizerFeature + +from absl import logging +import numpy as np +import tensorflow.compat.v1 as tf +import twml +import twml.layers + + +DEFAULT_SAMPLE_WEIGHT = 1 + + +class MDLFeature(PercentileDiscretizerFeature): + ''' Accumulates and calibrates a single sparse MDL feature. ''' + + +class MDLCalibrator(PercentileDiscretizerCalibrator): + ''' Accumulates features and their respective values for MDL calibration. + Internally, each feature's values is accumulated via its own ``MDLFeature`` object. + The steps for calibration are typically as follows: + + 1. accumulate feature values from batches by calling ``accumulate()``; + 2. calibrate all feature into MDL bin_vals by calling ``calibrate()``; and + 3. convert to a twml.layers.MDL layer by calling ``to_layer()``. + + ''' + + def to_layer(self, name=None): + """ + Returns a twml.layers.PercentileDiscretizer Layer + that can be used for feature discretization. + + Arguments: + name: + name-scope of the PercentileDiscretizer layer + """ + n_feature = len(self._discretizer_feature_dict) + max_discretizer_feature = n_feature * (self._n_bin + 1) + + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate()") + + if self._bin_ids.shape[0] != n_feature: + raise RuntimeError("Expecting self._bin_ids.shape[0] \ + != len(self._discretizer_feature_dict)") + if self._bin_vals.shape[0] != n_feature: + raise RuntimeError("Expecting self._bin_vals.shape[0] \ + != len(self._discretizer_feature_dict)") + + # can add at most #features * (n_bin+1) new feature ids + if 2**self._out_bits <= max_discretizer_feature: + raise ValueError("""Maximum number of features created by discretizer is + %d but requested that the output be limited to %d values (%d bits), + which is smaller than that. 
Please ensure the output has enough bits + to represent at least the new features""" + % (max_discretizer_feature, 2**self._out_bits, self._out_bits)) + + # build feature_offsets, hash_map_keys, hash_map_values + feature_offsets = np.arange(0, max_discretizer_feature, + self._n_bin + 1, dtype='int64') + hash_map_keys = np.array(list(self._hash_map.keys()), dtype=np.int64) + hash_map_values = np.array(list(self._hash_map.values()), dtype=np.float32) + + discretizer = twml.layers.MDL( + n_feature=n_feature, n_bin=self._n_bin, + name=name, out_bits=self._out_bits, + hash_keys=hash_map_keys, hash_values=hash_map_values, + bin_ids=self._bin_ids.flatten(), bin_values=self._bin_vals.flatten(), + feature_offsets=feature_offsets, + **self._kwargs + ) + + return discretizer + + def save(self, save_dir, name='calibrator', verbose=False): + '''Save the calibrator into the given save_directory. + Arguments: + save_dir: + name of the saving directory + name: + name for the graph scope. Passed to to_layer(name=name) to set + scope of layer. + ''' + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate().Cannot save() prior to calibrate()") + + layer_args = self.get_layer_args() + + calibrator_filename = os.path.join(save_dir, name + '.json.tf') + calibrator_dict = { + 'layer_args': layer_args, + 'saved_layer_scope': name + '/', + } + twml.write_file(calibrator_filename, calibrator_dict, encode='json') + + if verbose: + logging.info("The layer graph and other information necessary ") + logging.info("for multi-phase training is saved in directory:") + logging.info(save_dir) + logging.info("This directory can be specified as --init_from_dir argument.") + logging.info("") + logging.info("Other information is available in: %s.json.tf", name) + logging.info("This file can be loaded with twml.read_file(decode='json) to obtain ") + logging.info("layer_args, saved_layer_scope and variable_names") + + graph = tf.Graph() + # save graph for tensorboard as well + writer = tf.summary.FileWriter(logdir=save_dir, graph=graph) + + with tf.Session(graph=graph) as sess: + self.write_summary(writer, sess) + writer.flush() diff --git a/twml/twml/contrib/calibrators/percentile_discretizer.py b/twml/twml/contrib/calibrators/percentile_discretizer.py new file mode 100644 index 000000000..eefce62c2 --- /dev/null +++ b/twml/twml/contrib/calibrators/percentile_discretizer.py @@ -0,0 +1,577 @@ +# pylint: disable=arguments-differ,no-member,too-many-statements +''' Contains PercentileDiscretizerFeature and PercentileDiscretizerCalibrator used \ + for PercentileDiscretizer calibration ''' + + + +from .calibrator import CalibrationFeature, Calibrator + +import os +import numpy as np +import tensorflow.compat.v1 as tf +import tensorflow_hub as hub +import twml +import twml.layers + + +DEFAULT_SAMPLE_WEIGHT = 1 + + +class PercentileDiscretizerFeature(CalibrationFeature): + ''' Accumulates and calibrates a single sparse PercentileDiscretizer feature. ''' + + @staticmethod + def _gather_debug_info(values, indices, bin_vals, bin_counts_buffer): + ''' + Determine how many training values fell into a given bin during calibration. + This is calculated by finding the index of the first appearance of each bin + boundary in values (values may repeat, so that isn't trivially in indices.) + Subtracting each bin boundary index from the next tells you how many values fall in + that bin. + To get this to calculate the last bin correctly, len(values) is appended to the + list of bound indices. 
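+
+    For instance, with sorted ``values = [1, 1, 2, 5, 5]`` and
+    ``bin_vals = [1, 2, 5]``, the boundaries first appear at indices
+    ``[0, 2, 3]``, so the counts are ``diff([0, 2, 3, 5]) = [2, 1, 2]``.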
+ + This assumes that ``bin_vals`` excludes np.inf bin boundaries when + PercentileDiscretizer was calibrated + with fewer values than bins. + + Arguments: + values: + 1D ndarray of the PercentileDiscretizerFeature's accumulated values, sorted ascending + indices: + 1D int32 ndarray of the indices (in values) of the bin boundaries + bin_vals: + 1D ndarray containing the bin boundaries + bin_counts_buffer: + ndarray buffer for returning the PercentileDiscretizer histogram + ''' + # np.flatnonzero(np.diff(x)) gives you the indices i in x s.t. x[i] != x[i+1] + # append index of the last bin since that cannot be empty with how + # PercentileDiscretizer is implemented + nonempty_bins = np.append(np.flatnonzero(np.diff(bin_vals)), len(bin_vals) - 1) + bin_start_indices = indices.take(nonempty_bins) + + # if multiples of a bin's lower bound value exist, find the first one + for (i, idx) in enumerate(bin_start_indices): + cur_idx = idx + while cur_idx > 0 and values[cur_idx] == values[cur_idx - 1]: + bin_start_indices[i] = cur_idx = cur_idx - 1 + + # the end of each bin is the start of the next bin, + # until the last, which is the end of the array + # broadcast the counts to the nonempty bins, 0 otherwise + bin_counts_buffer[:] = 0 + bin_counts_buffer[nonempty_bins] = np.diff(np.append(bin_start_indices, values.size)) + + def calibrate( + self, + bin_vals, percentiles, percentile_indices, + bin_counts_buffer=None): + '''Calibrates the PercentileDiscretizerFeature into bin values for + use in PercentileDiscretizerCalibrator. + Note that this method can only be called once. + + Arguments: + bin_vals: + Row in the PercentileDiscretizerCalibrator.bin_vals matrix corresponding to this feature. + Will be updated with the results of the calibration. + A 1D ndarray. + percentiles: + 1D array of size n_bin with values ranging from 0 to 1. + For example, ``percentiles = np.linspace(0, 1, num=self._n_bin+1, dtype=np.float32)`` + percentile_indices: + Empty 1D array of size n_bin used to store intermediate results when + calling twml.twml_optim_nearest_interpolation(). + For example, np.empty(self._n_bin + 1, dtype=np.float32). + bin_counts_buffer: + optional ndarray buffer used for retaining count of values per PercentileDiscretizer + bucket (for debug and feature exploration purposes) + + Returns: + calibrated bin_vals for use by ``PercentileDiscretizerCalibrator`` + ''' + if self._calibrated: + raise RuntimeError("Can only calibrate once") + if bin_vals.ndim != 1: + raise RuntimeError("Expecting bin_vals row") + + # # concatenate values and weights buffers + self._concat_arrays() + feature_values = self._features_dict['values'] + feature_weights = self._features_dict['weights'] + + # get features ready for the bins, order array indices by feature values. 
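+    # (np.argsort yields the permutation that sorts feature_values ascending;
+    # the same permutation is applied to the weights so each value keeps its weight)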
+ indices = np.argsort(feature_values) + + # get ordered values and weights using array indices + values = feature_values.take(indices) + weights = feature_weights.take(indices) + + # Normalizes the sum of weights to be between 0 and 1 + weights = np.cumsum(weights, out=feature_weights) + weights -= weights[0] + if weights[-1] > 0: # prevent zero-division + weights /= weights[-1] + + # Check if we have less values than bin_vals + if values.size < bin_vals.size: + # Fills all the bins with a value that won't ever be reached + bin_vals.fill(np.inf) + # Forces the first to be -inf + bin_vals[0] = -np.inf + # Copies the values as boundaries + bin_vals[1:values.size + 1] = values + + if bin_counts_buffer is not None: + # slice out bins with +/-np.inf boundary -- their count will be zero anyway + # we can't just assume all other bins will have 1 value since there can be dups + short_indices = np.arange(values.size, dtype=np.int32) + bin_counts_buffer.fill(0) + self._gather_debug_info( + values, short_indices, bin_vals[1:values.size + 1], + bin_counts_buffer[1:values.size + 1]) + + else: + # Gets the indices for the values that define the boundary for the bins + indices_float = np.arange(0, weights.size, dtype=np.float32) + + # Gets things in the correct shape for the linear interpolation + weights = weights.reshape(1, weights.size) + indices_float = indices_float.reshape(1, weights.size) + + # wrap ndarrays into twml.Array + percentiles_tarray = twml.Array(percentiles.reshape(percentiles.size, 1)) + weights_tarray = twml.Array(weights) + indices_float_tarray = twml.Array(indices_float) + percentile_indices_tarray = twml.Array(percentile_indices.reshape(percentiles.size, 1)) + + # Performs the binary search to find the indices corresponding to the percentiles + err = twml.CLIB.twml_optim_nearest_interpolation( + percentile_indices_tarray.handle, percentiles_tarray.handle, # output, input + weights_tarray.handle, indices_float_tarray.handle # xs, ys + ) + if err != 1000: + raise ValueError("""twml.CLIB.twml_optim_nearest_interpolation + caught an error (see previous stdout). Error code: """ % err) + + indices = indices[:bin_vals.size] + indices[:] = percentile_indices + indices[0] = 0 + indices[-1] = weights.size - 1 + + # Gets the values at those indices and copies them into bin_vals + values.take(indices, out=bin_vals) + + # get # of values per bucket + if bin_counts_buffer is not None: + self._gather_debug_info(values, indices, bin_vals, bin_counts_buffer) + + self._calibrated = True + + +class PercentileDiscretizerCalibrator(Calibrator): + ''' Accumulates features and their respective values for PercentileDiscretizer calibration. + Internally, each feature's values is accumulated via its own + ``PercentileDiscretizerFeature`` object. + The steps for calibration are typically as follows: + + 1. accumulate feature values from batches by calling ``accumulate()``; + 2. calibrate all feature into PercentileDiscretizer bin_vals by calling ``calibrate()``; and + 3. convert to a twml.layers.PercentileDiscretizer layer by calling ``to_layer()``. + + ''' + + def __init__(self, n_bin, out_bits, bin_histogram=True, + allow_empty_calibration=False, **kwargs): + ''' Constructs an PercentileDiscretizerCalibrator instance. + + Arguments: + n_bin: + the number of bins per feature to use for PercentileDiscretizer. + Note that each feature actually maps to n_bin+1 output IDs. + out_bits: + The maximum number of bits to use for the output IDs. + 2**out_bits must be greater than bin_ids.size or an error is raised. 
+ bin_histogram: + When True (the default), gathers information during calibration + to build a bin_histogram. + allow_empty_calibration: + allows operation where we might not calibrate any features. + Default False to error out if no features were calibrated. + Typically, values of uncalibrated features pass through discretizers + untouched (though the feature ids will be truncated to obey out_bits). + ''' + super(PercentileDiscretizerCalibrator, self).__init__(**kwargs) + self._n_bin = n_bin + self._out_bits = out_bits + + self._bin_ids = None + self._bin_vals = np.empty(0, dtype=np.float32) # Note changed from 64 (v1) to 32 (v2) + + self._bin_histogram = bin_histogram + self._bin_histogram_dict = None + + self._hash_map_counter = 0 + self._hash_map = {} + + self._discretizer_feature_dict = {} + self._allow_empty_calibration = allow_empty_calibration + + @property + def bin_ids(self): + ''' + Gets bin_ids + ''' + return self._bin_ids + + @property + def bin_vals(self): + ''' + Gets bin_vals + ''' + return self._bin_vals + + @property + def hash_map(self): + ''' + Gets hash_map + ''' + return self._hash_map + + @property + def discretizer_feature_dict(self): + ''' + Gets feature_dict + ''' + return self._discretizer_feature_dict + + def accumulate_features(self, inputs, name): + ''' + Wrapper around accumulate for PercentileDiscretizer. + Arguments: + inputs: + batch that will be accumulated + name: + name of the tensor that will be accumulated + + ''' + sparse_tf = inputs[name] + indices = sparse_tf.indices[:, 1] + ids = sparse_tf.indices[:, 0] + weights = np.take(inputs["weights"], ids) + return self.accumulate(indices, sparse_tf.values, weights) + + def accumulate_feature(self, output): + ''' + Wrapper around accumulate for trainer API. + Arguments: + output: + output of prediction of build_graph for calibrator + ''' + return self.accumulate(output['feature_ids'], output['feature_values'], output['weights']) + + def accumulate(self, feature_keys, feature_vals, weights=None): + '''Accumulate a single batch of feature keys, values and weights. + + These are accumulate until ``calibrate()`` is called. + + Arguments: + feature_keys: + 1D int64 array of feature keys. + feature_vals: + 1D float array of feature values. Each element of this array + maps to the commensurate element in ``feature_keys``. + weights: + Defaults to weights of 1. + 1D array containing the weights of each feature key, value pair. + Typically, this is the weight of each sample (but you still need + to provide one weight per key,value pair). + Each element of this array maps to the commensurate element in feature_keys. 
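+
+    Example (sketch)::
+
+      calib = PercentileDiscretizerCalibrator(n_bin=2, out_bits=22)
+      keys = np.array([101, 101, 101, 202], dtype=np.int64)
+      vals = np.array([0.5, 1.5, 3.0, 2.0], dtype=np.float32)
+      calib.accumulate(keys, vals)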
+ ''' + if feature_keys.ndim != 1: + raise ValueError('Expecting 1D feature_keys, got %dD' % feature_keys.ndim) + if feature_vals.ndim != 1: + raise ValueError('Expecting 1D feature_values, got %dD' % feature_vals.ndim) + if feature_vals.size != feature_keys.size: + raise ValueError( + 'Expecting feature_keys.size == feature_values.size, got %d != %d' % + (feature_keys.size, feature_vals.size)) + if weights is not None: + weights = np.squeeze(weights) + if weights.ndim != 1: + raise ValueError('Expecting 1D weights, got %dD' % weights.ndim) + elif weights.size != feature_keys.size: + raise ValueError( + 'Expecting feature_keys.size == weights.size, got %d != %d' % + (feature_keys.size, weights.size)) + if weights is None: + weights = np.full(feature_vals.size, fill_value=DEFAULT_SAMPLE_WEIGHT) + unique_keys = np.unique(feature_keys) + for feature_id in unique_keys: + idx = np.where(feature_keys == feature_id) + if feature_id not in self._discretizer_feature_dict: + self._hash_map[feature_id] = self._hash_map_counter + # unlike v1, the hash_map_counter is incremented AFTER assignment. + # This makes the hash_map features zero-indexed: 0, 1, 2 instead of 1, 2, 3 + self._hash_map_counter += 1 + # creates a new cache if we never saw the feature before + discretizer_feature = PercentileDiscretizerFeature(feature_id) + self._discretizer_feature_dict[feature_id] = discretizer_feature + else: + discretizer_feature = self._discretizer_feature_dict[feature_id] + discretizer_feature.add_values({'values': feature_vals[idx], 'weights': weights[idx]}) + + def calibrate(self, debug=False): + ''' + Calibrates each PercentileDiscretizer feature after accumulation is complete. + + Arguments: + debug: + Boolean to request debug info be returned by the method. + (see Returns section below) + + The calibration results are stored in two matrices: + bin_ids: + 2D array of size number of accumulate ``features x n_bin+1``. + Contains the new IDs generated by PercentileDiscretizer. Each row maps to a feature. + Each row maps to different value bins. The IDs + are in the range ``1 -> bin_ids.size+1`` + bin_vals: + 2D array of the same size as bin_ids. + Each row maps to a feature. Each row contains the bin boundaries. + These boundaries represent feature values. + + Returns: + if debug is True, the method returns + + - 1D int64 array of feature_ids + - 2D float32 array copy of bin_vals (the bin boundaries) for each feature + - 2D int64 array of bin counts corresponding to the bin boundaries + + ''' + n_feature = len(self._discretizer_feature_dict) + if n_feature == 0 and not self._allow_empty_calibration: + raise RuntimeError("Need to accumulate some features for calibration\n" + "Likely, the calibration data is empty. 
This can\n" + "happen if the dataset is small, or if the following\n" + "cli args are set too low:\n" + " --discretizer_keep_rate (default=0.0008)\n" + " --discretizer_parts_downsampling_rate (default=0.2)\n" + "Consider increasing the values of these args.\n" + "To allow empty calibration data (and degenerate discretizer),\n" + "use the allow_empty_calibration input of the constructor.") + + self._bin_ids = np.arange(1, n_feature * (self._n_bin + 1) + 1) + self._bin_ids = self._bin_ids.reshape(n_feature, self._n_bin + 1) + + self._bin_vals.resize(n_feature, self._n_bin + 1) + + # buffers shared by PercentileDiscretizerFeature.calibrate() + percentile_indices = np.empty(self._n_bin + 1, dtype=np.float32) + + # Tensor from 0 to 1 in the number of steps provided + percentiles = np.linspace(0, 1, num=self._n_bin + 1, dtype=np.float32) + + if debug or self._bin_histogram: + debug_feature_ids = np.empty(n_feature, dtype=np.int64) + bin_counts = np.empty((n_feature, self._n_bin + 1), dtype=np.int64) + + # progress bar for calibration phase + progress_bar = tf.keras.utils.Progbar(n_feature) + + discretizer_features_dict = self._discretizer_feature_dict + for i, feature_id in enumerate(discretizer_features_dict): + if debug or self._bin_histogram: + debug_feature_ids[self._hash_map[feature_id]] = feature_id + bin_counts_buffer = bin_counts[self._hash_map[feature_id]] + else: + bin_counts_buffer = None + + # calibrate each PercentileDiscretizer feature (puts results in bin_vals) + discretizer_features_dict[feature_id].calibrate( + self._bin_vals[self._hash_map[feature_id]], # Gets feature-values + percentiles, percentile_indices, + bin_counts_buffer=bin_counts_buffer + ) + + # update progress bar 20 times + if (i % max(1.0, round(n_feature / 20)) == 0) or (i == n_feature - 1): + progress_bar.update(i + 1) + + super(PercentileDiscretizerCalibrator, self).calibrate() + + if self._bin_histogram: + # save bin histogram data for later + self._bin_histogram_dict = { + 'feature_ids': debug_feature_ids, + 'bin_counts': bin_counts, + 'bin_vals': self._bin_vals, + 'out_bits': self._out_bits, + } + + if debug: + return debug_feature_ids, self._bin_vals.copy(), bin_counts + + return None + + def _create_discretizer_layer(self, n_feature, hash_map_keys, hash_map_values, + feature_offsets, name): + return twml.layers.PercentileDiscretizer( + n_feature=n_feature, + n_bin=self._n_bin, + out_bits=self._out_bits, + bin_values=self._bin_vals.flatten(), + hash_keys=hash_map_keys, + hash_values=hash_map_values.astype(np.int64), + bin_ids=self._bin_ids.flatten().astype(np.int64), + feature_offsets=feature_offsets, + name=name, + **self._kwargs + ) + + def to_layer(self, name=None): + """ + Returns a twml.layers.PercentileDiscretizer Layer + that can be used for feature discretization. 
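+    Once calibrated, a typical use is (sketch)::
+
+      discretizer = calib.to_layer(name="percentile_discretizer")
+      output = discretizer(sparse_input, keep_inputs=False)
+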
+ + Arguments: + name: + name-scope of the PercentileDiscretizer layer + """ + n_feature = len(self._discretizer_feature_dict) + max_discretizer_feature = n_feature * (self._n_bin + 1) + + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate()") + + if self._bin_ids.shape[0] != n_feature: + raise RuntimeError("Expecting self._bin_ids.shape[0] \ + != len(self._discretizer_feature_dict)") + if self._bin_vals.shape[0] != n_feature: + raise RuntimeError("Expecting self._bin_vals.shape[0] \ + != len(self._discretizer_feature_dict)") + + # can add at most #features * (n_bin+1) new feature ids + if 2**self._out_bits <= max_discretizer_feature: + raise ValueError("""Maximum number of features created by discretizer is + %d but requested that the output be limited to %d values (%d bits), + which is smaller than that. Please ensure the output has enough bits + to represent at least the new features""" + % (max_discretizer_feature, 2**self._out_bits, self._out_bits)) + + # build feature_offsets, hash_map_keys, hash_map_values + feature_offsets = np.arange(0, max_discretizer_feature, + self._n_bin + 1, dtype='int64') + hash_map_keys = np.array(list(self._hash_map.keys()), dtype=np.int64) + hash_map_values = np.array(list(self._hash_map.values()), dtype=np.float32) + + discretizer = self._create_discretizer_layer(n_feature, hash_map_keys, + hash_map_values, feature_offsets, name) + + return discretizer + + def get_layer_args(self): + ''' + Returns layer arguments required to implement multi-phase training. + See twml.calibrator.Calibrator.get_layer_args for more detailed documentation. + ''' + layer_args = { + 'n_feature': len(self._discretizer_feature_dict), + 'n_bin': self._n_bin, + 'out_bits': self._out_bits, + } + + return layer_args + + def add_hub_signatures(self, name): + """ + Add Hub Signatures for each calibrator + + Arguments: + name: + Calibrator name + """ + sparse_tf = tf.sparse_placeholder(tf.float32) + calibrator_layer = self.to_layer() + hub.add_signature( + inputs=sparse_tf, + outputs=calibrator_layer(sparse_tf, keep_inputs=False), + name=name) + + def write_summary(self, writer, sess=None): + """ + This method is called by save() to write a histogram of + PercentileDiscretizer feature bins to disk. A histogram is included for each + feature. + + Arguments: + writer: + tf.summary.FilteWriter instance. + used to add summaries to event files for inclusion in tensorboard. + sess: + tf.Session instance. Used to produces summaries for the writer. + """ + bin_counts_ph = tf.placeholder(tf.int64) + bin_counts = self._bin_histogram_dict['bin_counts'] + + # Record that distribution into a histogram summary + histo = tf.summary.histogram("discretizer_feature_bin_counts", bin_counts_ph) + for i in range(bin_counts.shape[0]): + bin_counts_summary = sess.run(histo, feed_dict={bin_counts_ph: bin_counts[i]}) + writer.add_summary(bin_counts_summary, global_step=i) + + def write_summary_json(self, save_dir, name="default"): + """ + Export bin information to HDFS. + + Arguments: + save_dir: + name of the saving directory. + name: + prefix of the saved hub signature. Default (string): "default". + """ + # Since the size is small: (# of bins) * (# of features), we always dump the file. 
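+    # The JSON payload written below holds the feature ids, the per-feature bin
+    # boundaries, and the number of output bits used by the discretizer.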
+ discretizer_export_bin_filename = os.path.join(save_dir, name + '_bin.json') + discretizer_export_bin_dict = { + 'feature_ids': self._bin_histogram_dict['feature_ids'].tolist(), + 'bin_boundaries': self._bin_histogram_dict['bin_vals'].tolist(), + 'output_bits': self._bin_histogram_dict['out_bits'] + } + twml.write_file(discretizer_export_bin_filename, discretizer_export_bin_dict, encode='json') + + def save(self, save_dir, name="default", verbose=False): + '''Save the calibrator into the given save_directory using TF Hub. + Arguments: + save_dir: + name of the saving directory. + name: + prefix of the saved hub signature. Default (string): "default". + ''' + if not self._calibrated: + raise RuntimeError("Expecting prior call to calibrate().Cannot save() prior to calibrate()") + + # This module allows for the calibrator to save be saved as part of + # Tensorflow Hub (this will allow it to be used in further steps) + def calibrator_module(): + # Note that this is usually expecting a sparse_placeholder + inputs = tf.sparse_placeholder(tf.float32) + calibrator_layer = self.to_layer() + # creates the signature to the calibrator module + hub.add_signature( + inputs=inputs, + outputs=calibrator_layer(inputs, keep_inputs=False), + name=name) + # and another signature for keep_inputs mode + hub.add_signature( + inputs=inputs, + outputs=calibrator_layer(inputs, keep_inputs=True), + name=name + '_keep_inputs') + + # exports the module to the save_dir + spec = hub.create_module_spec(calibrator_module) + with tf.Graph().as_default(): + module = hub.Module(spec) + with tf.Session() as session: + module.export(save_dir, session) + + self.write_summary_json(save_dir, name) diff --git a/twml/twml/contrib/eventbus/input_fn.py b/twml/twml/contrib/eventbus/input_fn.py new file mode 100644 index 000000000..c184d9434 --- /dev/null +++ b/twml/twml/contrib/eventbus/input_fn.py @@ -0,0 +1,59 @@ +from reader import EventBusPipedBinaryRecordReader +import tensorflow.compat.v1 as tf +import twml + + +""" +This module provides input function for DeepBird v2 training. +The training data records are loaded from an EventBus reader. +""" + + +def get_eventbus_data_record_generator(eventbus_reader): + """ + This module provides a data record generater from EventBus reader. + + Args: + eventbus_reader: EventBus reader + + Returns: + gen: Data record generater + """ + eventbus_reader.initialize() + counter = [0] + + def gen(): + while True: + record = eventbus_reader.read() + if eventbus_reader.debug: + tf.logging.warn("counter: {}".format(counter[0])) + with open('tmp_record_{}.bin'.format(counter[0]), 'wb') as f: + f.write(record) + counter[0] = counter[0] + 1 + yield record + return gen + + +def get_eventbus_data_record_dataset(eventbus_reader, parse_fn, batch_size): + """ + This module generates batch data for training from a data record generator. + """ + dataset = tf.data.Dataset.from_generator( + get_eventbus_data_record_generator(eventbus_reader), tf.string, tf.TensorShape([])) + return dataset.batch(batch_size).map(parse_fn, num_parallel_calls=4).prefetch(buffer_size=10) + + +def get_train_input_fn(feature_config, params, parse_fn=None): + """ + This module provides input function for DeepBird v2 training. + It gets batched training data from data record generator. 
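+
+    Example (sketch; assumes ``params`` carries the EventBus reader fields used
+    below, e.g. ``jar_file``, ``num_eb_threads``, ``subscriber_id``)::
+
+      input_fn = get_train_input_fn(feature_config, params)
+      dataset = input_fn()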
+def get_train_input_fn(feature_config, params, parse_fn=None):
+  """
+  Provides the input function for DeepBird v2 training.
+  It returns batched training data from the data record generator.
+  """
+  eventbus_reader = EventBusPipedBinaryRecordReader(
+    params.jar_file, params.num_eb_threads, params.subscriber_id,
+    filter_str=params.filter_str, debug=params.debug)
+
+  train_parse_fn = parse_fn or twml.parsers.get_sparse_parse_fn(
+    feature_config, ["ids", "keys", "values", "batch_size", "weights"])
+
+  return lambda: get_eventbus_data_record_dataset(
+    eventbus_reader, train_parse_fn, params.train_batch_size)
diff --git a/twml/twml/contrib/eventbus/reader.py b/twml/twml/contrib/eventbus/reader.py
new file mode 100644
index 000000000..2f8e2749e
--- /dev/null
+++ b/twml/twml/contrib/eventbus/reader.py
@@ -0,0 +1,119 @@
+import io
+import logging
+import subprocess
+from threading import Lock
+
+"""
+This module provides a binary data record reader for EventBus data.
+It starts an EventBus subscriber in a separate process to receive EventBus streaming data.
+The subscriber is expected to write the received data to this module through a pipe.
+This module splits the incoming byte stream into binary data records, serving as a record reader.
+"""
+
+
+class BinaryRecordReader(object):
+  def initialize(self):
+    pass
+
+  def read(self):
+    """Read raw bytes for one record
+    """
+    raise NotImplementedError
+
+  def close(self):
+    pass
+
+
+class ReadableWrapper(object):
+  def __init__(self, internal):
+    self.internal = internal
+
+  def __getattr__(self, name):
+    return getattr(self.internal, name)
+
+  def readable(self):
+    return True
+
+
+class EventBusPipedBinaryRecordReader(BinaryRecordReader):
+
+  JAVA = '/usr/lib/jvm/java-11-twitter/bin/java'
+  RECORD_SEPARATOR_HEX = [
+    0x29, 0xd8, 0xd5, 0x06, 0x58, 0xcd, 0x4c, 0x29,
+    0xb2, 0xbc, 0x57, 0x99, 0x21, 0x71, 0xbd, 0xff
+  ]
+  # The separator is kept as raw bytes so it can be matched against the
+  # binary stdout stream of the subscriber process.
+  RECORD_SEPARATOR = bytes(RECORD_SEPARATOR_HEX)
+  RECORD_SEPARATOR_LENGTH = len(RECORD_SEPARATOR)
+  CHUNK_SIZE = 8192
+
+  def __init__(self, jar_file, num_eb_threads, subscriber_id,
+               filter_str=None, buffer_size=32768, debug=False):
+    self.jar_file = jar_file
+    self.num_eb_threads = num_eb_threads
+    self.subscriber_id = subscriber_id
+    self.filter_str = filter_str if filter_str else '""'
+    self.buffer_size = buffer_size
+    self.lock = Lock()
+    self._pipe = None
+    self._buffered_reader = None
+    self._bytes_buffer = None
+
+    self.debug = debug
+
+  def initialize(self):
+    if not self._pipe:
+      self._pipe = subprocess.Popen(
+        [
+          self.JAVA, '-jar', self.jar_file,
+          '-subscriberId', self.subscriber_id,
+          '-numThreads', str(self.num_eb_threads),
+          '-dataFilter', self.filter_str,
+          '-debug' if self.debug else ''
+        ],
+        stdout=subprocess.PIPE
+      )
+      self._buffered_reader = io.BufferedReader(
+        ReadableWrapper(self._pipe.stdout), self.buffer_size)
+      self._bytes_buffer = io.BytesIO()
+    else:
+      logging.warning('Already initialized')
+
+  def _find_next_record(self):
+    # Keep the last RECORD_SEPARATOR_LENGTH bytes of each chunk around, in
+    # case the separator straddles a chunk boundary.
+    tail = [b'']
+    while True:
+      chunk = tail[0] + self._buffered_reader.read(self.CHUNK_SIZE)
+      index = chunk.find(self.RECORD_SEPARATOR)
+      if index < 0:
+        self._bytes_buffer.write(chunk[:-self.RECORD_SEPARATOR_LENGTH])
+        tail[0] = chunk[-self.RECORD_SEPARATOR_LENGTH:]
+      else:
+        self._bytes_buffer.write(chunk[:index])
+        return chunk[(index + self.RECORD_SEPARATOR_LENGTH):]
+
+  def _read(self):
+    with self.lock:
+      remaining = self._find_next_record()
+      record = self._bytes_buffer.getvalue()
+      # clean up buffer
+      self._bytes_buffer.close()
+      self._bytes_buffer = io.BytesIO()
+      self._bytes_buffer.write(remaining)
+
+      return record
+
+  def read(self):
+    while True:
+      try:
+        return self._read()
+      except Exception as e:
+        logging.error("Error reading bytes for next record: {}".format(e))
+        if self.debug:
+          raise
+
+  def close(self):
+    try:
+      self._bytes_buffer.close()
+      self._buffered_reader.close()
+      self._pipe.terminate()
+    except Exception as e:
+      logging.error("Error closing reader: {}".format(e))
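The chunked separator scan in `_find_next_record` above can be illustrated on an in-memory byte stream. This is a hedged sketch (shortened separator, hypothetical record contents), showing how the retained tail catches a separator that straddles a chunk boundary.

```python
import io

# Standalone illustration of the separator-scanning technique; SEP is a
# shortened stand-in for the 16-byte record separator used above.
SEP = b"\x29\xd8\xd5\x06"
stream = io.BytesIO(b"first-record" + SEP + b"second-record" + SEP)

def read_record(reader, buf, chunk_size=5):
    tail = b""
    while True:
        chunk = tail + reader.read(chunk_size)
        idx = chunk.find(SEP)
        if idx < 0:
            # Keep the last len(SEP) bytes in case the separator is split
            # across two chunks; flush the rest into the record buffer.
            buf.write(chunk[:-len(SEP)])
            tail = chunk[-len(SEP):]
        else:
            buf.write(chunk[:idx])
            return buf.getvalue(), chunk[idx + len(SEP):]

record, remainder = read_record(stream, io.BytesIO())
assert record == b"first-record"
```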
{}".format(e)) + if self.debug: + raise + + def close(self): + try: + self._bytes_buffer.close() + self._buffered_reader.close() + self._pipe.terminate() + except Exception as e: + logging.error("Error closing reader: {}".format(e)) diff --git a/twml/twml/contrib/export/__init__.py b/twml/twml/contrib/export/__init__.py new file mode 100644 index 000000000..99892dcfa --- /dev/null +++ b/twml/twml/contrib/export/__init__.py @@ -0,0 +1,2 @@ +from . import export_fn # noqa: F401 +from . import exporters # noqa: F401 diff --git a/twml/twml/contrib/export/export_fn.py b/twml/twml/contrib/export/export_fn.py new file mode 100644 index 000000000..6e59fff07 --- /dev/null +++ b/twml/twml/contrib/export/export_fn.py @@ -0,0 +1,264 @@ +""" +Functions for exporting models for different modes. +""" +from collections import OrderedDict +import os + +import tensorflow.compat.v1 as tf +from tensorflow.python.estimator.export import export +import twml +import yaml + + +def get_sparse_batch_supervised_input_receiver_fn(feature_config, keep_fields=None): + """Gets supervised_input_receiver_fn that decodes a BatchPredictionRequest as sparse tensors + with labels and weights as defined in feature_config. + This input_receiver_fn is required for exporting models with 'train' mode to be trained with + Java API + + Args: + feature_config (FeatureConfig): deepbird v2 feature config object + keep_fields (list): list of fields to keep + + Returns: + supervised_input_receiver_fn: input_receiver_fn used for train mode + """ + def supervised_input_receiver_fn(): + serialized_request = tf.placeholder(dtype=tf.uint8, name='request') + receiver_tensors = {'request': serialized_request} + + bpr = twml.contrib.readers.HashedBatchPredictionRequest(serialized_request, feature_config) + features = bpr.get_sparse_features() if keep_fields is None else bpr.get_features(keep_fields) + features['weights'] = bpr.weights + labels = bpr.labels + features, labels = bpr.apply_filter(features, labels) + + return export.SupervisedInputReceiver(features, labels, receiver_tensors) + + return supervised_input_receiver_fn + + +def update_build_graph_fn_for_train(build_graph_fn): + """Updates a build_graph_fn by inserting in graph output a serialized BatchPredictionResponse + similar to the export_output_fns for serving. + The key difference here is that + 1. We insert serialized BatchPredictionResponse in graph output with key 'prediction' instead of + creating an export_output object. This is because of the way estimators export model in 'train' + mode doesn't take custom export_output + 2. We only do it when `mode == 'train'` to avoid altering the graph when exporting + for 'infer' mode + + Args: + build_graph_fn (Callable): deepbird v2 build graph function + + Returns: + new_build_graph_fn: An updated build_graph_fn that inserts serialized BatchPredictResponse + to graph output when in 'train' mode + """ + def new_build_graph_fn(features, label, mode, params, config=None): + output = build_graph_fn(features, label, mode, params, config) + if mode == tf.estimator.ModeKeys.TRAIN: + output.update( + twml.export_output_fns.batch_prediction_continuous_output_fn(output)[ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs + ) + return output + return new_build_graph_fn + + +def export_model_for_train_and_infer( + trainer, feature_config, keep_fields, export_dir, as_text=False): + """Function for exporting model with both 'train' and 'infer' mode. 
+def export_model_for_train_and_infer(
+    trainer, feature_config, keep_fields, export_dir, as_text=False):
+  """Function for exporting a model with both 'train' and 'infer' modes.
+
+  This means the exported saved_model.pb will contain two meta graphs, one with tag 'train'
+  and the other with tag 'serve', and it can be loaded by the Java API with either tag
+  depending on the use case.
+
+  Args:
+    trainer (DataRecordTrainer): deepbird v2 DataRecordTrainer
+    feature_config (FeatureConfig): deepbird v2 feature config
+    keep_fields (list of string): list of field keys, e.g.
+      ('ids', 'keys', 'values', 'batch_size', 'total_size', 'codes')
+    export_dir (str): a directory (local or hdfs) to export the model to
+    as_text (bool): if True, write a human-readable 'saved_model.pbtxt' text file;
+      otherwise write a binary 'saved_model.pb'. Default False
+  """
+  train_input_receiver_fn = get_sparse_batch_supervised_input_receiver_fn(
+    feature_config, keep_fields)
+  predict_input_receiver_fn = twml.parsers.get_sparse_serving_input_receiver_fn(
+    feature_config, keep_fields)
+  trainer._export_output_fn = twml.export_output_fns.batch_prediction_continuous_output_fn
+  trainer._build_graph_fn = update_build_graph_fn_for_train(trainer._build_graph_fn)
+  trainer._estimator._export_all_saved_models(
+    export_dir_base=export_dir,
+    input_receiver_fn_map={
+      tf.estimator.ModeKeys.TRAIN: train_input_receiver_fn,
+      tf.estimator.ModeKeys.PREDICT: predict_input_receiver_fn
+    },
+    as_text=as_text,
+  )
+
+  trainer.export_model_effects(export_dir)
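A hedged usage sketch of the export entry point above; `trainer`, `feature_config`, and the HDFS path are assumptions standing in for an existing DeepBird v2 job.

```python
# Assumes an already-trained DataRecordTrainer and its FeatureConfig.
export_model_for_train_and_infer(
    trainer=trainer,                    # existing DataRecordTrainer
    feature_config=feature_config,      # existing FeatureConfig
    keep_fields=('ids', 'keys', 'values', 'batch_size', 'total_size', 'codes'),
    export_dir='hdfs://default/user/alice/my_model/export',  # hypothetical path
    as_text=False,
)
```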
+def export_all_models_with_receivers(estimator, export_dir,
+                                     train_input_receiver_fn,
+                                     eval_input_receiver_fn,
+                                     predict_input_receiver_fn,
+                                     export_output_fn,
+                                     export_modes=('train', 'eval', 'predict'),
+                                     register_model_fn=None,
+                                     feature_spec=None,
+                                     checkpoint_path=None,
+                                     log_features=True):
+  """
+  Function for exporting a model with train, eval, and infer modes.
+
+  Args:
+    estimator:
+      Should be of type tf.estimator.Estimator.
+      You can get this from a trainer using trainer.estimator.
+    export_dir:
+      Directory to export the model.
+    train_input_receiver_fn:
+      Input receiver for the train interface.
+    eval_input_receiver_fn:
+      Input receiver for the eval interface.
+    predict_input_receiver_fn:
+      Input receiver for the predict interface.
+    export_output_fn:
+      export_output_fn to be used for serving.
+    export_modes:
+      A list specifying which modes to export; can be "train", "eval", "predict".
+      Defaults to ["train", "eval", "predict"].
+    register_model_fn:
+      An optional function which is called with export_dir after models are exported.
+      Defaults to None.
+  Returns:
+    The timestamped directory the models are exported to.
+  """
+  # TODO: Fix for hogwild / distributed training.
+
+  if export_dir is None:
+    raise ValueError("export_dir cannot be None")
+  export_dir = twml.util.sanitize_hdfs_path(export_dir)
+  input_receiver_fn_map = {}
+
+  if "train" in export_modes:
+    input_receiver_fn_map[tf.estimator.ModeKeys.TRAIN] = train_input_receiver_fn
+
+  if "eval" in export_modes:
+    input_receiver_fn_map[tf.estimator.ModeKeys.EVAL] = eval_input_receiver_fn
+
+  if "predict" in export_modes:
+    input_receiver_fn_map[tf.estimator.ModeKeys.PREDICT] = predict_input_receiver_fn
+
+  export_dir = estimator._export_all_saved_models(
+    export_dir_base=export_dir,
+    input_receiver_fn_map=input_receiver_fn_map,
+    checkpoint_path=checkpoint_path,
+  )
+
+  if register_model_fn is not None:
+    register_model_fn(export_dir, feature_spec, log_features)
+
+  return export_dir
+
+
+def export_all_models(trainer,
+                      export_dir,
+                      parse_fn,
+                      serving_input_receiver_fn,
+                      export_output_fn=None,
+                      export_modes=('train', 'eval', 'predict'),
+                      feature_spec=None,
+                      checkpoint=None,
+                      log_features=True):
+  """
+  Function for exporting a model with train, eval, and infer modes.
+
+  Args:
+    trainer:
+      An object of type twml.trainers.Trainer.
+    export_dir:
+      Directory to export the model.
+    parse_fn:
+      The parse function used to parse the inputs for train and eval.
+    serving_input_receiver_fn:
+      The input receiver function used during serving.
+    export_output_fn:
+      export_output_fn to be used for serving.
+    export_modes:
+      A list specifying which modes to export; can be "train", "eval", "predict".
+      Defaults to ["train", "eval", "predict"].
+    feature_spec:
+      A dictionary obtained from FeatureConfig.get_feature_spec() to serialize
+      as feature_spec.yaml in export_dir.
+      Defaults to None.
+  Returns:
+    The timestamped directory the models are exported to.
+  """
+  # Only export from chief in hogwild or distributed modes.
+  if trainer.params.get('distributed', False) and not trainer.estimator.config.is_chief:
+    tf.logging.info("Trainer.export_model ignored due to instance not being chief.")
+    return
+
+  if feature_spec is None:
+    if getattr(trainer, '_feature_config') is None:
+      raise ValueError("feature_spec is set to None. "
+                       "Please pass feature_spec=feature_config.get_feature_spec() to the export_all_models function")
+    else:
+      feature_spec = trainer._feature_config.get_feature_spec()
+
+  export_dir = twml.util.sanitize_hdfs_path(export_dir)
+  old_export_output_fn = trainer._export_output_fn
+  trainer._export_output_fn = export_output_fn
+  supervised_input_receiver_fn = twml.parsers.convert_to_supervised_input_receiver_fn(parse_fn)
+  if not checkpoint:
+    checkpoint = trainer.best_or_latest_checkpoint
+
+  export_dir = export_all_models_with_receivers(estimator=trainer.estimator,
+                                                export_dir=export_dir,
+                                                train_input_receiver_fn=supervised_input_receiver_fn,
+                                                eval_input_receiver_fn=supervised_input_receiver_fn,
+                                                predict_input_receiver_fn=serving_input_receiver_fn,
+                                                export_output_fn=export_output_fn,
+                                                export_modes=export_modes,
+                                                register_model_fn=trainer.export_model_effects,
+                                                feature_spec=feature_spec,
+                                                checkpoint_path=checkpoint,
+                                                log_features=log_features)
+  trainer._export_output_fn = old_export_output_fn
+  return export_dir
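A hedged usage sketch of `export_all_models`; the `trainer`, `feature_config`, field list, and path are assumed, not prescribed by the library.

```python
# Assumes an existing trainer and feature_config; path and fields hypothetical.
serving_input_fn = twml.parsers.get_sparse_serving_input_receiver_fn(
    feature_config, ('ids', 'keys', 'values', 'batch_size'))

export_all_models(
    trainer=trainer,
    export_dir='hdfs://default/user/alice/my_model/export',
    parse_fn=feature_config.get_parse_fn(),
    serving_input_receiver_fn=serving_input_fn,
    export_modes=('train', 'eval', 'predict'),
)
```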
+def export_feature_spec(dir_path, feature_spec_dict):
+  """
+  Exports a FeatureConfig.get_feature_spec() dict to <dir_path>/feature_spec.yaml.
+  """
+  def ordered_dict_representer(dumper, data):
+    return dumper.represent_mapping('tag:yaml.org,2002:map', data.items())
+
+  try:
+    # needed for Python 2
+    yaml.add_representer(str, yaml.representer.SafeRepresenter.represent_str)
+    yaml.add_representer(unicode, yaml.representer.SafeRepresenter.represent_unicode)
+  except NameError:
+    # the 'unicode' type doesn't exist on Python 3;
+    # PyYAML handles unicode correctly there
+    pass
+
+  yaml.add_representer(OrderedDict, ordered_dict_representer)
+
+  fbase = "feature_spec.yaml"
+  fname = fbase.encode('utf-8') if type(dir_path) != str else fbase
+  file_path = os.path.join(dir_path, fname)
+  with tf.io.gfile.GFile(file_path, mode='w') as f:
+    yaml.dump(feature_spec_dict, f, default_flow_style=False, allow_unicode=True)
+  tf.logging.info("Exported feature spec to %s" % file_path)
+
+  return file_path
+
+
+# Keep the alias for compatibility.
+get_supervised_input_receiver_fn = twml.parsers.convert_to_supervised_input_receiver_fn
diff --git a/twml/twml/contrib/export/exporters.py b/twml/twml/contrib/export/exporters.py
new file mode 100644
index 000000000..122955cbc
--- /dev/null
+++ b/twml/twml/contrib/export/exporters.py
@@ -0,0 +1,145 @@
+"""
+Wrappers around tf.estimator.Exporters to export models and save checkpoints.
+"""
+import os
+
+import tensorflow.compat.v1 as tf
+from tensorflow.python.estimator import exporter
+import twml
+
+
+class _AllSavedModelsExporter(tf.estimator.Exporter):
+  """Internal exporter class to be used for exporting models for different modes."""
+
+  def __init__(self,
+               name,
+               input_receiver_fn_map,
+               backup_checkpoints,
+               assets_extra=None,
+               as_text=False):
+    """
+    Args:
+      name: A unique name to be used for the exporter. This is used in the export path.
+      input_receiver_fn_map: A map of tf.estimator.ModeKeys to input_receiver_fns.
+      backup_checkpoints: A flag to specify if backups of checkpoints need to be made.
+      assets_extra: Additional assets to be included in the exported model.
+      as_text: Specifies if the exported model should be in a human-readable text format.
+    """
+    self._name = name
+    self._input_receiver_fn_map = input_receiver_fn_map
+    self._backup_checkpoints = backup_checkpoints
+    self._assets_extra = assets_extra
+    self._as_text = as_text
+
+  @property
+  def name(self):
+    return self._name
+
+  def export(self, estimator, export_path, checkpoint_path, eval_result,
+             is_the_final_export):
+    del is_the_final_export
+
+    export_path = twml.util.sanitize_hdfs_path(export_path)
+    checkpoint_path = twml.util.sanitize_hdfs_path(checkpoint_path)
+
+    if self._backup_checkpoints:
+      backup_path = os.path.join(export_path, "checkpoints")
+      # Ensure backup_path is created. makedirs passes if dir already exists.
+      tf.io.gfile.makedirs(backup_path)
+      twml.util.backup_checkpoint(checkpoint_path, backup_path, empty_backup=False)
+
+    export_result = estimator.experimental_export_all_saved_models(
+      export_path,
+      self._input_receiver_fn_map,
+      assets_extra=self._assets_extra,
+      as_text=self._as_text,
+      checkpoint_path=checkpoint_path)
+
+    return export_result
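Before the concrete exporters that follow, a hedged sketch of how such an exporter is typically wired into `tf.estimator.train_and_evaluate`; the receiver functions and eval input function are assumed to exist.

```python
import tensorflow.compat.v1 as tf

# Hypothetical wiring; serving/supervised receiver fns would come from
# twml parsers, eval_input_fn from the training job.
input_receiver_fn_map = {
    tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn,
    tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,
}
exporter = BestExporter(                  # defined just below
    name='best_exporter',
    input_receiver_fn_map=input_receiver_fn_map,
    backup_checkpoints=True,
    exports_to_keep=3,
)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, exporters=[exporter])
```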
+ """ + + def __init__(self, + name='best_exporter', + input_receiver_fn_map=None, + backup_checkpoints=True, + event_file_pattern='eval/*.tfevents.*', + compare_fn=exporter._loss_smaller, + assets_extra=None, + as_text=False, + exports_to_keep=5): + """ + Args: + name: A unique name to be used for the exporter. This is used in the export path. + input_receiver_fn_map: A map of tf.estimator.ModeKeys to input_receiver_fns. + backup_checkpoints: A flag to specify if backups of checkpoints need to be made. + + Note: + Check the following documentation for more information about the remaining args: + https://www.tensorflow.org/api_docs/python/tf/estimator/BestExporter + """ + serving_input_receiver_fn = input_receiver_fn_map.get(tf.estimator.ModeKeys.PREDICT) + + super(BestExporter, self).__init__( + name, serving_input_receiver_fn, event_file_pattern, compare_fn, + assets_extra, as_text, exports_to_keep) + + if not hasattr(self, "_saved_model_exporter"): + raise AttributeError( + "_saved_model_exporter needs to exist for this exporter to work." + " This is potentially broken because of an internal change in Tensorflow") + + # Override the saved_model_exporter with SaveAllmodelsexporter + self._saved_model_exporter = _AllSavedModelsExporter( + name, input_receiver_fn_map, backup_checkpoints, assets_extra, as_text) + + +class LatestExporter(tf.estimator.LatestExporter): + """ + This class inherits from tf.estimator.LatestExporter with the following differences: + - It also creates a backup of the latest checkpoint. + - It can export the model for multiple modes. + + A backup / export is performed everytime the evaluated metric is better + than previous models. + """ + + def __init__(self, + name='latest_exporter', + input_receiver_fn_map=None, + backup_checkpoints=True, + assets_extra=None, + as_text=False, + exports_to_keep=5): + """ + Args: + name: A unique name to be used for the exporter. This is used in the export path. + input_receiver_fn_map: A map of tf.estimator.ModeKeys to input_receiver_fns. + backup_checkpoints: A flag to specify if backups of checkpoints need to be made. + + Note: + Check the following documentation for more information about the remaining args: + https://www.tensorflow.org/api_docs/python/tf/estimator/LatestExporter + """ + serving_input_receiver_fn = input_receiver_fn_map.get(tf.estimator.ModeKeys.PREDICT) + + super(LatestExporter, self).__init__( + name, serving_input_receiver_fn, assets_extra, as_text, exports_to_keep) + + if not hasattr(self, "_saved_model_exporter"): + raise AttributeError( + "_saved_model_exporter needs to exist for this exporter to work." + " This is potentially broken because of an internal change in Tensorflow") + + # Override the saved_model_exporter with SaveAllmodelsexporter + self._saved_model_exporter = _AllSavedModelsExporter( + name, input_receiver_fn_map, backup_checkpoints, assets_extra, as_text) diff --git a/twml/twml/contrib/feature_config.py b/twml/twml/contrib/feature_config.py new file mode 100644 index 000000000..833695751 --- /dev/null +++ b/twml/twml/contrib/feature_config.py @@ -0,0 +1,85 @@ +""" +Feature configuration for DeepBird jobs returns dictionary of sparse and dense Features +""" +from twitter.deepbird.io.legacy.contrib import feature_config +import twml + + +class FeatureConfig(feature_config.FeatureConfig): + def get_feature_spec(self): + """ + Generates a serialization-friendly dict representing this FeatureConfig. 
+ """ + doc = super(FeatureConfig, self).get_feature_spec() + + # Override the class in the spec. + doc["class"] = "twml.contrib.FeatureConfig" + + return doc + + +class FeatureConfigBuilder(feature_config.FeatureConfigBuilder): + # Overwrite self.build() to return twml.FeatureConfig instead + def build(self): + """ + Returns an instance of FeatureConfig with the features passed to the FeatureConfigBuilder. + """ + + ( + keep_tensors, + keep_sparse_tensors, + feature_map, + features_add, + feature_name_to_feature_parser, + feature_in_bq_name, + ) = self._build() + + discretize_dict = {} + for config in self._sparse_extraction_configs: + if config.discretize_num_bins and config.discretize_output_size_bits: + if config.discretize_type == "percentile": + calibrator = twml.contrib.calibrators.PercentileDiscretizerCalibrator + elif config.discretize_type == "hashed_percentile": + calibrator = twml.contrib.calibrators.HashedPercentileDiscretizerCalibrator + elif config.discretize_type == "hashing": + calibrator = twml.contrib.calibrators.HashingDiscretizerCalibrator + else: + raise ValueError("Unsupported discretizer type: " + config.discretize_type) + discretize_dict[config.output_name] = calibrator( + config.discretize_num_bins, + config.discretize_output_size_bits, + allow_empty_calibration=config.allow_empty_calibration, + ) + elif config.discretize_num_bins or config.discretize_output_size_bits: + raise ValueError( + "Discretize_num_bins AND discretize_output_size_bits need to be in the FeatureConfig" + ) + + return FeatureConfig( + features={}, + labels=self._labels, + weight=self._weight, + filters=self._filter_features, + tensor_types=keep_tensors, + sparse_tensor_types=keep_sparse_tensors, + feature_types=feature_map, + sparse_extraction_configs=self._sparse_extraction_configs, + feature_extraction_configs=self._feature_extraction_configs, + feature_group_extraction_configs=self._feature_group_extraction_configs, + image_configs=self._image_configs, + discretize_config=discretize_dict, + feature_ids=features_add, + decode_mode=self._decode_mode, + legacy_sparse=self._legacy_sparse, + feature_name_to_feature_parser=feature_name_to_feature_parser, + feature_in_bq_name=feature_in_bq_name, + ) + + +TensorExtractionConfig = feature_config.TensorExtractionConfig + +FeatureGroupExtractionConfig = feature_config.FeatureGroupExtractionConfig + +ImageExtractionConfig = feature_config.ImageExtractionConfig + +_set_tensor_namedtuple = feature_config._set_tensor_namedtuple diff --git a/twml/twml/contrib/feature_config_parsers.py b/twml/twml/contrib/feature_config_parsers.py new file mode 100644 index 000000000..83c402e2e --- /dev/null +++ b/twml/twml/contrib/feature_config_parsers.py @@ -0,0 +1,224 @@ +"""Utility functions to create FeatureConfig objects from feature_spec.yaml files""" +import os +import re + +import tensorflow.compat.v1 as tf +import yaml +from twml.feature_config import FeatureConfigBuilder +from twml.contrib.feature_config import FeatureConfigBuilder as FeatureConfigBuilderV2 + + +def _get_config_version(config_dict): + doc = config_dict + supported_classes = { + "twml.FeatureConfig": "v1", + "twml.contrib.FeatureConfig": "v2" + } + if "class" not in doc: + raise ValueError("'class' key not found") + if doc["class"] not in supported_classes.keys(): + raise ValueError("Class %s not supported. 
Supported classes are %s"
+                     % (doc["class"], supported_classes.keys()))
+  return supported_classes[doc["class"]]
+
+
+def _validate_config_dict_v1(config_dict):
+  """
+  Validate spec exported by twml.FeatureConfig
+  """
+  doc = config_dict
+
+  def malformed_error(msg):
+    raise ValueError("twml.FeatureConfig: Malformed feature_spec. %s" % msg)
+
+  if doc["class"] != "twml.FeatureConfig":
+    malformed_error("'class' is not twml.FeatureConfig")
+  if "format" not in doc:
+    malformed_error("'format' key not found")
+
+  # validate spec exported by twml.FeatureConfig
+  if doc["format"] == "exported":
+    dict_keys = ["features", "labels", "weight", "tensors", "sparse_tensors"]
+    for key in dict_keys:
+      if key not in doc:
+        malformed_error("'%s' key not found" % key)
+      if type(doc[key]) != dict:
+        malformed_error("'%s' is not a dict" % key)
+    if "filters" not in doc:
+      malformed_error("'filters' key not found")
+    elif type(doc["filters"]) != list:
+      malformed_error("'filters' is not a list")
+
+  # validate spec provided by modeler
+  elif doc["format"] == "manual":
+    raise NotImplementedError("Manual config support not yet implemented")
+  else:
+    malformed_error("'format' must be 'exported' or 'manual'")
+
+
+def _validate_config_dict_v2(config_dict):
+  """
+  Validate spec exported by twml.contrib.FeatureConfig
+  """
+  doc = config_dict
+
+  def malformed_error(msg):
+    raise ValueError("twml.contrib.FeatureConfig: Malformed feature_spec. %s" % msg)
+
+  if doc["class"] != "twml.contrib.FeatureConfig":
+    malformed_error("'class' is not twml.contrib.FeatureConfig")
+  if "format" not in doc:
+    malformed_error("'format' key not found")
+
+  # validate spec exported by twml.contrib.FeatureConfig (basic validation only)
+  if doc["format"] == "exported":
+    dict_keys = ["features", "labels", "weight", "tensors", "sparseTensors", "discretizeConfig"]
+    for key in dict_keys:
+      if key not in doc:
+        malformed_error("'%s' key not found" % key)
+      if type(doc[key]) != dict:
+        malformed_error("'%s' is not a dict" % key)
+    list_keys = ["sparseFeatureGroups", "denseFeatureGroups", "denseFeatures", "images", "filters"]
+    for key in list_keys:
+      if key not in doc:
+        malformed_error("'%s' key not found" % key)
+      if type(doc[key]) != list:
+        malformed_error("'%s' is not a list" % key)
+
+  # validate spec provided by modeler
+  elif doc["format"] == "manual":
+    raise NotImplementedError("Manual config support not yet implemented")
+  else:
+    malformed_error("'format' must be 'exported' or 'manual'")
+
+
+def _create_feature_config_v1(config_dict, data_spec_path):
+  fc_builder = FeatureConfigBuilder(data_spec_path)
+
+  if config_dict["format"] == "exported":
+    # add features
+    for feature_info in config_dict["features"].values():
+      feature_name = re.escape(feature_info["featureName"])
+      feature_group = feature_info["featureGroup"]
+      fc_builder.add_feature(feature_name, feature_group)
+    # add labels
+    labels = []
+    for label_info in config_dict["labels"].values():
+      labels.append(label_info["featureName"])
+    fc_builder.add_labels(labels)
+    # feature filters
+    for feature_name in config_dict["filters"]:
+      fc_builder.add_filter(feature_name)
+    # weight
+    if config_dict["weight"]:
+      weight_feature = list(config_dict["weight"].values())[0]["featureName"]
+      fc_builder.define_weight(weight_feature)
+  else:
+    raise ValueError("Format '%s' not implemented" % config_dict["format"])
+
+  return fc_builder.build()
+
+
+def _create_feature_config_v2(config_dict, data_spec_path):
+  fc_builder = FeatureConfigBuilderV2(data_spec_path)
+
+  if
config_dict["format"] == "exported": + # add sparse group extraction configs + for sparse_group in config_dict["sparseFeatureGroups"]: + fids = sparse_group["features"].keys() + fnames = [sparse_group["features"][fid]["featureName"] for fid in fids] + fc_builder.extract_features_as_hashed_sparse( + feature_regexes=[re.escape(fname) for fname in fnames], + output_tensor_name=sparse_group["outputName"], + hash_space_size_bits=sparse_group["hashSpaceBits"], + discretize_num_bins=sparse_group["discretize"]["numBins"], + discretize_output_size_bits=sparse_group["discretize"]["outputSizeBits"], + discretize_type=sparse_group["discretize"]["type"], + type_filter=sparse_group["filterType"]) + + # add dense group extraction configs + for dense_group in config_dict["denseFeatureGroups"]: + fids = dense_group["features"].keys() + fnames = [dense_group["features"][fid]["featureName"] for fid in fids] + fc_builder.extract_feature_group( + feature_regexes=[re.escape(fname) for fname in fnames], + group_name=dense_group["outputName"], + type_filter=dense_group["filterType"], + default_value=dense_group["defaultValue"]) + + # add dense feature configs + for dense_features in config_dict["denseFeatures"]: + fids = dense_features["features"].keys() + fnames = [dense_features["features"][fid]["featureName"] for fid in fids] + default_value = dense_features["defaultValue"] + if len(fnames) == 1 and type(default_value) != dict: + fc_builder.extract_feature( + feature_name=re.escape(fnames[0]), + expected_shape=dense_features["expectedShape"], + default_value=dense_features["defaultValue"]) + else: + fc_builder.extract_features( + feature_regexes=[re.escape(fname) for fname in fnames], + default_value_map=dense_features["defaultValue"]) + + # add image feature configs + for image in config_dict["images"]: + fc_builder.extract_image( + feature_name=image["featureName"], + preprocess=image["preprocess"], + out_type=tf.as_dtype(image["outType"].lower()), + channels=image["channels"], + default_image=image["defaultImage"], + ) + + # add other tensor features (non-image) + tensor_fnames = [] + image_fnames = [img["featureName"] for img in config_dict["images"]] + for tensor_fname in config_dict["tensors"]: + if tensor_fname not in image_fnames: + tensor_fnames.append(tensor_fname) + for sparse_tensor_fname in config_dict["sparseTensors"]: + tensor_fnames.append(sparse_tensor_fname) + fc_builder.extract_tensors(tensor_fnames) + + # add labels + labels = [] + for label_info in config_dict["labels"].values(): + labels.append(label_info["featureName"]) + fc_builder.add_labels(labels) + + else: + raise ValueError("Format '%s' not implemented" % config_dict["format"]) + + return fc_builder.build() + + +def create_feature_config_from_dict(config_dict, data_spec_path): + """ + Create a FeatureConfig object from a feature spec dict. + """ + config_version = _get_config_version(config_dict) + if config_version == "v1": + _validate_config_dict_v1(config_dict) + feature_config = _create_feature_config_v1(config_dict, data_spec_path) + elif config_version == "v2": + _validate_config_dict_v2(config_dict) + feature_config = _create_feature_config_v2(config_dict, data_spec_path) + else: + raise ValueError("version not supported") + + return feature_config + + +def create_feature_config(config_path, data_spec_path): + """ + Create a FeatureConfig object from a feature_spec.yaml file. 
+ """ + _, ext = os.path.splitext(config_path) + if ext not in ['.yaml', '.yml']: + raise ValueError("create_feature_config_from_yaml: Only .yaml/.yml supported") + + with tf.io.gfile.GFile(config_path, mode='r') as fs: + config_dict = yaml.safe_load(fs) + + return create_feature_config_from_dict(config_dict, data_spec_path) diff --git a/twml/twml/contrib/feature_importances/__init__.py b/twml/twml/contrib/feature_importances/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/twml/twml/contrib/feature_importances/feature_importances.py b/twml/twml/contrib/feature_importances/feature_importances.py new file mode 100644 index 000000000..a8bfcc129 --- /dev/null +++ b/twml/twml/contrib/feature_importances/feature_importances.py @@ -0,0 +1,414 @@ +# checkstyle: noqa + +import time +from collections import defaultdict + +from com.twitter.mlmetastore.modelrepo.client import ModelRepoClient +from com.twitter.mlmetastore.modelrepo.core import FeatureImportance, FeatureNames +from twitter.deepbird.io.util import match_feature_regex_list + +from twml.contrib.feature_importances.helpers import ( + _get_feature_name_from_config, + _get_feature_types_from_records, + _get_metrics_hook, + _expand_prefix, + longest_common_prefix, + write_list_to_hdfs_gfile) +from twml.contrib.feature_importances.feature_permutation import PermutedInputFnFactory +from twml.tracking import ExperimentTracker + +from tensorflow.compat.v1 import logging +from requests.exceptions import HTTPError, RetryError +from queue import Queue + + +SERIAL = "serial" +TREE = "tree" +INDIVIDUAL = "Individual" +GROUP = "Group" +ROC_AUC = "roc_auc" +RCE = "rce" +LOSS = "loss" + + +def _repartition(feature_list_queue, fnames_ftypes, split_feature_group_on_period): + """ + Iterate through letters to partition each feature by prefix, and then put each tuple + (prefix, feature_partition) into the feature_list_queue + Args: + prefix (str): The prefix shared by each feature in list_of_feature_types + feature_list_queue (Queue<(str, list<(str, str)>)>): The queue of feature groups + fnames_ftypes (list<(str, str)>): List of (fname, ftype) pairs. Each fname begins with prefix + split_feature_group_on_period (str): If true, require that feature groups end in a period + Returns: + Updated queue with each group in fnames_ftypes + """ + assert len(fnames_ftypes) > 1 + + split_character = "." if split_feature_group_on_period else None + # Compute the longest prefix of the words + prefix = longest_common_prefix( + strings=[fname for fname, _ in fnames_ftypes], split_character=split_character) + + # Separate the features by prefix + prefix_to_features = defaultdict(list) + for fname, ftype in fnames_ftypes: + assert fname.startswith(prefix) + new_prefix = _expand_prefix(fname=fname, prefix=prefix, split_character=split_character) + prefix_to_features[new_prefix].append((fname, ftype)) + + # Add all of the new partitions to the queue + for new_prefix, fname_ftype_list in prefix_to_features.items(): + extended_new_prefix = longest_common_prefix( + strings=[fname for fname, _ in fname_ftype_list], split_character=split_character) + assert extended_new_prefix.startswith(new_prefix) + feature_list_queue.put((extended_new_prefix, fname_ftype_list)) + return feature_list_queue + + +def _infer_if_is_metric_larger_the_better(stopping_metric): + # Infers whether a metric should be interpreted such that larger numbers are better (e.g. ROC_AUC), as opposed to + # larger numbers being worse (e.g. 
LOSS) + if stopping_metric is None: + raise ValueError("Error: Stopping Metric cannot be None") + elif stopping_metric.startswith(LOSS): + logging.info("Interpreting {} to be a metric where larger numbers are worse".format(stopping_metric)) + is_metric_larger_the_better = False + else: + logging.info("Interpreting {} to be a metric where larger numbers are better".format(stopping_metric)) + is_metric_larger_the_better = True + return is_metric_larger_the_better + + +def _check_whether_tree_should_expand(baseline_performance, computed_performance, sensitivity, stopping_metric, is_metric_larger_the_better): + """ + Returns True if + - the metric is positive (e.g. ROC_AUC) and computed_performance is nontrivially smaller than the baseline_performance + - the metric is negative (e.g. LOSS) and computed_performance is nontrivially larger than the baseline_performance + """ + difference = ((baseline_performance[stopping_metric] - computed_performance[stopping_metric]) / + baseline_performance[stopping_metric]) + + if not is_metric_larger_the_better: + difference = -difference + + logging.info( + "Found a {} difference of {}. Sensitivity is {}.".format("positive" if is_metric_larger_the_better else "negative", difference, sensitivity)) + return difference > sensitivity + + +def _compute_multiple_permuted_performances_from_trainer( + factory, fname_ftypes, trainer, parse_fn, record_count): + """Compute performances with fname and fype permuted + """ + metrics_hook = _get_metrics_hook(trainer) + trainer._estimator.evaluate( + input_fn=factory.get_permuted_input_fn( + batch_size=trainer._params.eval_batch_size, parse_fn=parse_fn, fname_ftypes=fname_ftypes), + steps=(record_count + trainer._params.eval_batch_size) // trainer._params.eval_batch_size, + hooks=[metrics_hook], + checkpoint_path=trainer.best_or_latest_checkpoint) + return metrics_hook.metric_values + + +def _get_extra_feature_group_performances(factory, trainer, parse_fn, extra_groups, feature_to_type, record_count): + """Compute performance differences for the extra feature groups + """ + extra_group_feature_performance_results = {} + for group_name, raw_feature_regex_list in extra_groups.items(): + start = time.time() + fnames = match_feature_regex_list( + features=feature_to_type.keys(), + feature_regex_list=[regex for regex in raw_feature_regex_list], + preprocess=False, + as_dict=False) + + fnames_ftypes = [(fname, feature_to_type[fname]) for fname in fnames] + + logging.info("Extracted extra group {} with features {}".format(group_name, fnames_ftypes)) + extra_group_feature_performance_results[group_name] = _compute_multiple_permuted_performances_from_trainer( + factory=factory, fname_ftypes=fnames_ftypes, + trainer=trainer, parse_fn=parse_fn, record_count=record_count) + logging.info("\n\nImportances computed for {} in {} seconds \n\n".format( + group_name, int(time.time() - start))) + return extra_group_feature_performance_results + + +def _feature_importances_tree_algorithm( + data_dir, trainer, parse_fn, fnames, stopping_metric, file_list=None, datarecord_filter_fn=None, split_feature_group_on_period=True, + record_count=99999, is_metric_larger_the_better=None, sensitivity=0.025, extra_groups=None, dont_build_tree=False): + """Tree algorithm for feature and feature group importances. This algorithm build a prefix tree of + the feature names and then traverses the tree with a BFS. 
At each node (aka group of features with + a shared prefix) the algorithm computes the performance of the model when we permute all features + in the group. The algorithm only zooms-in on groups that impact the performance by more than + sensitivity. As a result, features that affect the model performance by less than sensitivity will + not have an exact importance. + Args: + data_dir: (str): The location of the training or testing data to compute importances over. + If None, the trainer._eval_files are used + trainer: (DataRecordTrainer): A DataRecordTrainer object + parse_fn: (function): The parse_fn used by eval_input_fn + fnames (list): The list of feature names + stopping_metric (str): The metric to use to determine when to stop expanding trees + file_list (list): The list of filenames. Exactly one of file_list and data_dir should be + provided + datarecord_filter_fn (function): a function takes a single data sample in com.twitter.ml.api.ttypes.DataRecord format + and return a boolean value, to indicate if this data record should be kept in feature importance module or not. + split_feature_group_on_period (boolean): If true, split feature groups by period rather than on + optimal prefix + record_count (int): The number of records to compute importances over + is_metric_larger_the_better (boolean): If true, assume that stopping_metric is a metric where larger + values are better (e.g. ROC-AUC) + sensitivity (float): The smallest change in performance to continue to expand the tree + extra_groups (dict>): A dictionary mapping the name of extra feature groups to the list of + the names of the features in the group. You should only supply a value for this argument if you have a set + of features that you want to evaluate as a group but don't share a prefix + dont_build_tree (boolean): If True, don't build the tree and only compute the extra_groups importances + Returns: + A dictionary that contains the individual and group feature importances + """ + factory = PermutedInputFnFactory( + data_dir=data_dir, record_count=record_count, file_list=file_list, datarecord_filter_fn=datarecord_filter_fn) + baseline_performance = _compute_multiple_permuted_performances_from_trainer( + factory=factory, fname_ftypes=[], + trainer=trainer, parse_fn=parse_fn, record_count=record_count) + out = {"None": baseline_performance} + + if stopping_metric not in baseline_performance: + raise ValueError("The stopping metric '{}' not found in baseline_performance. Metrics are {}".format( + stopping_metric, list(baseline_performance.keys()))) + + is_metric_larger_the_better = ( + is_metric_larger_the_better if is_metric_larger_the_better is not None else _infer_if_is_metric_larger_the_better(stopping_metric)) + logging.info("Using {} as the stopping metric for the tree algorithm".format(stopping_metric)) + + feature_to_type = _get_feature_types_from_records(records=factory.records, fnames=fnames) + all_feature_types = list(feature_to_type.items()) + + individual_feature_performances = {} + feature_group_performances = {} + if dont_build_tree: + logging.info("Not building feature importance trie. 
Will only compute importances for the extra_groups")
+  else:
+    logging.info("Building feature importance trie")
+    # Each element in the Queue will be a tuple of (prefix, list_of_feature_type_pairs) where
+    # each feature in list_of_feature_type_pairs will have the prefix "prefix"
+    feature_list_queue = _repartition(
+      feature_list_queue=Queue(), fnames_ftypes=all_feature_types, split_feature_group_on_period=split_feature_group_on_period)
+
+    while not feature_list_queue.empty():
+      # Pop the queue. We should never have an empty list in the queue
+      prefix, fnames_ftypes = feature_list_queue.get()
+      assert len(fnames_ftypes) > 0
+
+      # Compute performance from permuting all features in fname_ftypes
+      logging.info(
+        "\n\nComputing importances for {} ({}...). {} elements left in the queue \n\n".format(
+          prefix, fnames_ftypes[:5], feature_list_queue.qsize()))
+      start = time.time()
+      computed_performance = _compute_multiple_permuted_performances_from_trainer(
+        factory=factory, fname_ftypes=fnames_ftypes,
+        trainer=trainer, parse_fn=parse_fn, record_count=record_count)
+      logging.info("\n\nImportances computed for {} in {} seconds \n\n".format(
+        prefix, int(time.time() - start)))
+      if len(fnames_ftypes) == 1:
+        individual_feature_performances[fnames_ftypes[0][0]] = computed_performance
+      else:
+        feature_group_performances[prefix] = computed_performance
+      # Dig deeper into the features in fname_ftypes only if there is more than one feature in the
+      # list and the performance drop is nontrivial
+      logging.info("Checking performance for {} ({}...)".format(prefix, fnames_ftypes[:5]))
+      check = _check_whether_tree_should_expand(
+        baseline_performance=baseline_performance, computed_performance=computed_performance,
+        sensitivity=sensitivity, stopping_metric=stopping_metric, is_metric_larger_the_better=is_metric_larger_the_better)
+      if len(fnames_ftypes) > 1 and check:
+        logging.info("Expanding {} ({}...)".format(prefix, fnames_ftypes[:5]))
+        feature_list_queue = _repartition(
+          feature_list_queue=feature_list_queue, fnames_ftypes=fnames_ftypes, split_feature_group_on_period=split_feature_group_on_period)
+      else:
+        logging.info("Not expanding {} ({}...)".format(prefix, fnames_ftypes[:5]))
+
+  # Baseline performance is grouped in with individual_feature_importance_results
+  individual_feature_performance_results = dict(
+    out, **{k: v for k, v in individual_feature_performances.items()})
+  group_feature_performance_results = {k: v for k, v in feature_group_performances.items()}
+
+  if extra_groups is not None:
+    logging.info("Computing performances for extra groups {}".format(extra_groups.keys()))
+    for group_name, performances in _get_extra_feature_group_performances(
+        factory=factory,
+        trainer=trainer,
+        parse_fn=parse_fn,
+        extra_groups=extra_groups,
+        feature_to_type=feature_to_type,
+        record_count=record_count).items():
+      group_feature_performance_results[group_name] = performances
+  else:
+    logging.info("Not computing performances for extra groups")
+
+  return {INDIVIDUAL: individual_feature_performance_results,
+          GROUP: group_feature_performance_results}
+
+
+def _feature_importances_serial_algorithm(
+    data_dir, trainer, parse_fn, fnames, file_list=None, datarecord_filter_fn=None, factory=None, record_count=99999):
+  """Serial algorithm for feature importances. This algorithm computes the
+  importance of each feature.
+ """ + factory = PermutedInputFnFactory( + data_dir=data_dir, record_count=record_count, file_list=file_list, datarecord_filter_fn=datarecord_filter_fn) + feature_to_type = _get_feature_types_from_records(records=factory.records, fnames=fnames) + + out = {} + for fname, ftype in list(feature_to_type.items()) + [(None, None)]: + logging.info("\n\nComputing importances for {}\n\n".format(fname)) + start = time.time() + fname_ftypes = [(fname, ftype)] if fname is not None else [] + out[str(fname)] = _compute_multiple_permuted_performances_from_trainer( + factory=factory, fname_ftypes=fname_ftypes, + trainer=trainer, parse_fn=parse_fn, record_count=record_count) + logging.info("\n\nImportances computed for {} in {} seconds \n\n".format( + fname, int(time.time() - start))) + # The serial algorithm does not compute group feature results. + return {INDIVIDUAL: out, GROUP: {}} + + +def _process_feature_name_for_mldash(feature_name): + # Using a forward slash in the name causes feature importance writing to fail because strato interprets it as + # part of a url + return feature_name.replace("/", "__") + + +def compute_feature_importances( + trainer, data_dir=None, feature_config=None, algorithm=TREE, parse_fn=None, datarecord_filter_fn=None, **kwargs): + """Perform a feature importance analysis on a trained model + Args: + trainer: (DataRecordTrainer): A DataRecordTrainer object + data_dir: (str): The location of the training or testing data to compute importances over. + If None, the trainer._eval_files are used + feature_config (contrib.FeatureConfig): The feature config object. If this is not provided, it + is taken from the trainer + algorithm (str): The algorithm to use + parse_fn: (function): The parse_fn used by eval_input_fn. By default this is + feature_config.get_parse_fn() + datarecord_filter_fn (function): a function takes a single data sample in com.twitter.ml.api.ttypes.DataRecord format + and return a boolean value, to indicate if this data record should be kept in feature importance module or not. + """ + + # We only use the trainer's eval files if an override data_dir is not provided + if data_dir is None: + logging.info("Using trainer._eval_files (found {} as files)".format(trainer._eval_files)) + file_list = trainer._eval_files + else: + logging.info("data_dir provided. 
Looking at {} for data.".format(data_dir)) + file_list = None + + feature_config = feature_config or trainer._feature_config + out = {} + if not feature_config: + logging.warn("WARN: Not computing feature importance because trainer._feature_config is None") + out = None + else: + parse_fn = parse_fn if parse_fn is not None else feature_config.get_parse_fn() + fnames = _get_feature_name_from_config(feature_config) + logging.info("Computing importances for {}".format(fnames)) + logging.info("Using the {} feature importance computation algorithm".format(algorithm)) + algorithm = { + SERIAL: _feature_importances_serial_algorithm, + TREE: _feature_importances_tree_algorithm}[algorithm] + out = algorithm(data_dir=data_dir, trainer=trainer, parse_fn=parse_fn, fnames=fnames, file_list=file_list, datarecord_filter_fn=datarecord_filter_fn, **kwargs) + return out + + +def write_feature_importances_to_hdfs( + trainer, feature_importances, output_path=None, metric="roc_auc"): + """Publish a feature importance analysis to hdfs as a tsv + Args: + (see compute_feature_importances for other args) + trainer (Trainer) + feature_importances (dict): Dictionary of feature importances + output_path (str): The remote or local file to write the feature importances to. If not + provided, this is inferred to be the trainer save dir + metric (str): The metric to write to tsv + """ + # String formatting appends (Individual) or (Group) to feature name depending on type + perfs = {"{} ({})".format(k, importance_key) if k != "None" else k: v[metric] + for importance_key, importance_value in feature_importances.items() + for k, v in importance_value.items()} + + output_path = ("{}/feature_importances-{}".format( + trainer._save_dir[:-1] if trainer._save_dir.endswith('/') else trainer._save_dir, + output_path if output_path is not None else str(time.time()))) + + if len(perfs) > 0: + logging.info("Writing feature_importances for {} to hdfs".format(perfs.keys())) + entries = [ + { + "name": name, + "drop": perfs["None"] - perfs[name], + "pdrop": 100 * (perfs["None"] - perfs[name]) / (perfs["None"] + 1e-8), + "perf": perfs[name] + } for name in perfs.keys()] + out = ["Name\tPerformance Drop\tPercent Performance Drop\tPerformance"] + for entry in sorted(entries, key=lambda d: d["drop"]): + out.append("{name}\t{drop}\t{pdrop}%\t{perf}".format(**entry)) + logging.info("\n".join(out)) + write_list_to_hdfs_gfile(out, output_path) + logging.info("Wrote feature feature_importances to {}".format(output_path)) + else: + logging.info("Not writing feature_importances to hdfs") + return output_path + + +def write_feature_importances_to_ml_dash(trainer, feature_importances, feature_config=None): + # type: (DataRecordTrainer, FeatureConfig, dict) -> None + """Publish feature importances + all feature names to ML Metastore + Args: + trainer: (DataRecordTrainer): A DataRecordTrainer object + feature_config (contrib.FeatureConfig): The feature config object. 
If this is not provided, it + is taken from the trainer + feature_importances (dict, default=None): Dictionary of precomputed feature importances + feature_importance_metric (str, default=None): The metric to write to ML Dashboard + """ + experiment_tracking_path = trainer.experiment_tracker.tracking_path\ + if trainer.experiment_tracker.tracking_path\ + else ExperimentTracker.guess_path(trainer._save_dir) + + logging.info('Computing feature importances for run: {}'.format(experiment_tracking_path)) + + feature_importance_list = [] + for key in feature_importances: + for feature, imps in feature_importances[key].items(): + logging.info('FEATURE NAME: {}'.format(feature)) + feature_name = feature.split(' (').pop(0) + for metric_name, value in imps.items(): + try: + imps[metric_name] = float(value) + logging.info('Wrote feature importance value {} for metric: {}'.format(str(value), metric_name)) + except Exception as ex: + logging.error("Skipping writing metric:{} to ML Metastore due to invalid metric value: {} or value type: {}. Exception: {}".format(metric_name, str(value), type(value), str(ex))) + pass + + feature_importance_list.append(FeatureImportance( + run_id=experiment_tracking_path, + feature_name=_process_feature_name_for_mldash(feature_name), + feature_importance_metrics=imps, + is_group=key == GROUP + )) + +# setting feature config to match the one used in compute_feature_importances + feature_config = feature_config or trainer._feature_config + feature_names = FeatureNames( + run_id=experiment_tracking_path, + names=list(feature_config.features.keys()) + ) + + try: + client = ModelRepoClient() + logging.info('Writing feature importances to ML Metastore') + client.add_feature_importances(feature_importance_list) + logging.info('Writing feature names to ML Metastore') + client.add_feature_names(feature_names) + except (HTTPError, RetryError) as err: + logging.error('Feature importance is not being written due to: ' + 'HTTPError when attempting to write to ML Metastore: \n{}.'.format(err)) diff --git a/twml/twml/contrib/feature_importances/feature_permutation.py b/twml/twml/contrib/feature_importances/feature_permutation.py new file mode 100644 index 000000000..809f5fde0 --- /dev/null +++ b/twml/twml/contrib/feature_importances/feature_permutation.py @@ -0,0 +1,129 @@ +from copy import deepcopy +import random +import types + +from twitter.deepbird.util.thrift.simple_converters import ( + bytes_to_thrift_object, thrift_object_to_bytes) + +from tensorflow.compat.v1 import logging +from com.twitter.ml.api.ttypes import DataRecord # pylint: disable=import-error +import tensorflow.compat.v1 as tf +import twml + + +class PermutedInputFnFactory(object): + + def __init__(self, data_dir, record_count, file_list=None, datarecord_filter_fn=None): + """ + Args: + data_dir (str): The location of the records on hdfs + record_count (int): The number of records to process + file_list (list, default=None): The list of data files on HDFS. If provided, use this instead + of data_dir + datarecord_filter_fn (function): a function takes a single data sample in com.twitter.ml.api.ttypes.DataRecord format + and return a boolean value, to indicate if this data record should be kept in feature importance module or not. + """ + if not (data_dir is None) ^ (file_list is None): + raise ValueError("Exactly one of data_dir and file_list can be provided. 
Got {} for data_dir and {} for file_list".format( + data_dir, file_list)) + + file_list = file_list if file_list is not None else twml.util.list_files(twml.util.preprocess_path(data_dir)) + _next_batch = twml.input_fns.default_input_fn(file_list, 1, lambda x: x, + num_threads=2, shuffle=True, shuffle_files=True) + self.records = [] + # Validate datarecord_filter_fn + if datarecord_filter_fn is not None and not isinstance(datarecord_filter_fn, types.FunctionType): + raise TypeError("datarecord_filter_fn is not function type") + with tf.Session() as sess: + for i in range(record_count): + try: + record = bytes_to_thrift_object(sess.run(_next_batch)[0], DataRecord) + if datarecord_filter_fn is None or datarecord_filter_fn(record): + self.records.append(record) + except tf.errors.OutOfRangeError: + logging.info("Stopping after reading {} records out of {}".format(i, record_count)) + break + if datarecord_filter_fn: + logging.info("datarecord_filter_fn has been applied; keeping {} records out of {}".format(len(self.records), record_count)) + + def _get_record_generator(self): + return (thrift_object_to_bytes(r) for r in self.records) + + def get_permuted_input_fn(self, batch_size, parse_fn, fname_ftypes): + """Get an input function that passes in a preset number of records that have been feature permuted + Args: + parse_fn (function): The function to parse inputs + fname_ftypes: (list<(str, str)>): The names and types of the features to permute + """ + def permuted_parse_pyfn(bytes_array): + out = [] + for b in bytes_array: + rec = bytes_to_thrift_object(b, DataRecord) + if fname_ftypes: + rec = _permutate_features(rec, fname_ftypes=fname_ftypes, records=self.records) + out.append(thrift_object_to_bytes(rec)) + return [out] + + def permuted_parse_fn(bytes_tensor): + parsed_bytes_tensor = parse_fn(tf.py_func(permuted_parse_pyfn, [bytes_tensor], tf.string)) + return parsed_bytes_tensor + + def input_fn(batch_size=batch_size, parse_fn=parse_fn, factory=self): + return (tf.data.Dataset + .from_generator(self._get_record_generator, tf.string) + .batch(batch_size) + .map(permuted_parse_fn, 4) + .make_one_shot_iterator() + .get_next()) + return input_fn + + +def _permutate_features(rec, fname_ftypes, records): + """Replace a feature value with a value from random selected record + Args: + rec: (datarecord): A datarecord returned from DataRecordGenerator + fname_ftypes: (list<(str, str)>): The names and types of the features to permute + records: (list): The records to sample from + Returns: + The record with the feature permuted + """ + rec_new = deepcopy(rec) + rec_replace = random.choice(records) + + # If the replacement datarecord does not have the feature type entirely, add it in + # to make the logic a bit simpler + for fname, feature_type in fname_ftypes: + fid = twml.feature_id(fname)[0] + if rec_replace.__dict__.get(feature_type, None) is None: + rec_replace.__dict__[feature_type] = ( + dict() if feature_type != 'binaryFeatures' else set()) + if rec_new.__dict__.get(feature_type, None) is None: + rec_new.__dict__[feature_type] = ( + dict() if feature_type != 'binaryFeatures' else set()) + + if feature_type != 'binaryFeatures': + if fid not in rec_replace.__dict__[feature_type] and fid in rec_new.__dict__.get(feature_type, dict()): + # If the replacement datarecord does not contain the feature but the original does + del rec_new.__dict__[feature_type][fid] + elif fid in rec_replace.__dict__[feature_type]: + # If the replacement datarecord does contain the feature + if 
rec_new.__dict__[feature_type] is None: + rec_new.__dict__[feature_type] = dict() + rec_new.__dict__[feature_type][fid] = rec_replace.__dict__[feature_type][fid] + else: + # If neither datarecord contains this feature + pass + else: + if fid not in rec_replace.__dict__[feature_type] and fid in rec_new.__dict__.get(feature_type, set()): + # If the replacement datarecord does not contain the feature but the original does + rec_new.__dict__[feature_type].remove(fid) + elif fid in rec_replace.__dict__[feature_type]: + # If the replacement datarecord does contain the feature + if rec_new.__dict__[feature_type] is None: + rec_new.__dict__[feature_type] = set() + rec_new.__dict__[feature_type].add(fid) + # If neither datarecord contains this feature + else: + # If neither datarecord contains this feature + pass + return rec_new diff --git a/twml/twml/contrib/feature_importances/helpers.py b/twml/twml/contrib/feature_importances/helpers.py new file mode 100644 index 000000000..f3f600e8b --- /dev/null +++ b/twml/twml/contrib/feature_importances/helpers.py @@ -0,0 +1,96 @@ +import uuid + +from tensorflow.compat.v1 import logging +import twml +import tensorflow.compat.v1 as tf + + +def write_list_to_hdfs_gfile(list_to_write, output_path): + """Use tensorflow gfile to write a list to a location on hdfs""" + locname = "/tmp/{}".format(str(uuid.uuid4())) + with open(locname, "w") as f: + for row in list_to_write: + f.write("%s\n" % row) + tf.io.gfile.copy(locname, output_path, overwrite=False) + + +def decode_str_or_unicode(str_or_unicode): + return str_or_unicode.decode() if hasattr(str_or_unicode, 'decode') else str_or_unicode + + +def longest_common_prefix(strings, split_character): + """ + Args: + string (list): The list of strings to find the longest common prefix of + split_character (str): If not None, require that the return string end in this character or + be the length of the entire string + Returns: + The string corresponding to the longest common prefix + """ + sorted_strings = sorted(strings) + s1, s2 = sorted_strings[0], sorted_strings[-1] + if s1 == s2: + # If the strings are the same, just return the full string + out = s1 + else: + # If the strings are not the same, return the longest common prefix optionally ending in split_character + ix = 0 + for i in range(min(len(s1), len(s2))): + if s1[i] != s2[i]: + break + if split_character is None or s1[i] == split_character: + ix = i + 1 + out = s1[:ix] + return out + + +def _expand_prefix(fname, prefix, split_character): + if len(fname) == len(prefix): + # If the prefix is already the full feature, just take the feature name + out = fname + elif split_character is None: + # Advance the prefix by one character + out = fname[:len(prefix) + 1] + else: + # Advance the prefix to the next instance of split_character or the end of the string + for ix in range(len(prefix), len(fname)): + if fname[ix] == split_character: + break + out = fname[:ix + 1] + return out + + +def _get_feature_types_from_records(records, fnames): + # This method gets the types of the features in fnames by looking at the datarecords themselves. + # The reason why we do this rather than extract the feature types from the feature_config is + # that the feature naming conventions in the feature_config are different from those in the + # datarecords. 
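As a quick aside, the behaviour of the two prefix helpers defined above can be illustrated with hypothetical feature names; the values below are assumptions for demonstration, not twml features.

```python
# Hypothetical feature names illustrating longest_common_prefix / _expand_prefix.
names = ["user.engagement.clicks", "user.engagement.likes", "user.profile.age"]

# With split_character=".", the common prefix must end on a period:
assert longest_common_prefix(names, split_character=".") == "user."

# _expand_prefix advances one period-delimited component at a time:
assert _expand_prefix("user.engagement.clicks", "user.", ".") == "user.engagement."
```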
+  fids = [twml.feature_id(fname)[0] for fname in fnames]
+  feature_to_type = {}
+  for record in records:
+    for feature_type, values in record.__dict__.items():
+      if values is not None:
+        included_ids = set(values)
+        for fname, fid in zip(fnames, fids):
+          if fid in included_ids:
+            feature_to_type[fname] = feature_type
+  return feature_to_type
+
+
+def _get_metrics_hook(trainer):
+  def get_metrics_fn(trainer=trainer):
+    return {k: v[0] for k, v in trainer.current_estimator_spec.eval_metric_ops.items()}
+  return twml.hooks.GetMetricsHook(get_metrics_fn=get_metrics_fn)
+
+
+def _get_feature_name_from_config(feature_config):
+  """Extract the names of the features on a feature config object
+  """
+  decoded_feature_names = []
+  for f in feature_config.get_feature_spec()['features'].values():
+    try:
+      fname = decode_str_or_unicode(f['featureName'])
+    except UnicodeEncodeError as e:
+      logging.error("Encountered decoding exception when decoding %s: %s" % (f, e))
+      # skip names that cannot be decoded rather than appending a stale value
+      continue
+    decoded_feature_names.append(fname)
+  return decoded_feature_names
diff --git a/twml/twml/contrib/hooks.py b/twml/twml/contrib/hooks.py
new file mode 100644
index 000000000..6d68831fc
--- /dev/null
+++ b/twml/twml/contrib/hooks.py
@@ -0,0 +1,42 @@
+import datetime
+
+from absl import logging
+import pytz
+import tensorflow.compat.v1 as tf
+
+
+class StopAtTimeHook(tf.train.SessionRunHook):
+  """
+  Hook that stops training at a fixed datetime
+  """
+
+  def __init__(self, stop_time):
+    """
+    Arguments:
+      stop_time:
+        a datetime.datetime or a datetime.timedelta specifying when to stop.
+        For naive datetime.datetime objects (with no time zone specified),
+        UTC time zone is assumed.
+    """
+    if isinstance(stop_time, datetime.timedelta):
+      self._stop_datetime = pytz.utc.localize(datetime.datetime.utcnow() + stop_time)
+    elif isinstance(stop_time, datetime.datetime):
+      if stop_time.tzinfo is None:
+        self._stop_datetime = pytz.utc.localize(stop_time)
+      else:
+        self._stop_datetime = stop_time.astimezone(pytz.UTC)
+    else:
+      raise ValueError("Expecting datetime or timedelta for stop_time arg")
+    self._stop_requested = False
+
+  def after_run(self, run_context, run_values):
+    delta = self._stop_datetime - pytz.utc.localize(datetime.datetime.utcnow())
+    if delta.total_seconds() <= 0:
+      logging.info("StopAtTimeHook reached stop_time; requesting stop")
+      run_context.request_stop()
+      self._stop_requested = True
+
+  @property
+  def stop_requested(self):
+    """ true if this hook requested a stop """
+    return self._stop_requested
diff --git a/twml/twml/contrib/initializers.py b/twml/twml/contrib/initializers.py
new file mode 100644
index 000000000..52bad3a19
--- /dev/null
+++ b/twml/twml/contrib/initializers.py
@@ -0,0 +1,61 @@
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+
+TWML_INIT_FEED_KEY = "TWML_INIT_FEED_COLLECTION"
+
+
+class PartitionConstant(tf.keras.initializers.Constant):
+  """A constant initializer that supports partitions"""
+
+  def __call__(self, shape, dtype=None, partition_info=None):
+    if partition_info is not None:
+      if not isinstance(self.value, np.ndarray):
+        raise ValueError(
+          "Currently, PartitionConstant only supports "
+          "partitioning on np.ndarrays. 
Got {}".format(type(self.value).__name__)) + offsets = partition_info.var_offset + indices = tuple([slice(offset, offset + size) for offset, size in zip(offsets, shape)]) + subset = self.value[indices] + return subset + else: + return self.value + + +partition_constant_initializer = PartitionConstant + + +class PlaceholderInitializer(tf.keras.initializers.Initializer): + """A placeholder initializer that supports partitions""" + + def __init__(self, shape, dtype): + self.dtype = dtype + self.value = tf.placeholder(dtype=dtype, shape=shape) + + def __call__(self, shape, dtype=None, partition_info=None): + if partition_info is not None: + if self.dtype != dtype: + raise ValueError("dtype does not match placeholder dtype") + offsets = partition_info.var_offset + indices = tuple([slice(offset, offset + size) for offset, size in zip(offsets, shape)]) + subset = self.value[indices] + return subset + else: + return self.value + + +def get_init_feed_dict(): + """Get the init feed dictionary to be used when running the init op.""" + # Get the reference to the collection. + init_feed_collection = tf.get_collection(TWML_INIT_FEED_KEY) + init_feed_dict = {} + for d in init_feed_collection: + init_feed_dict.update(d) + return init_feed_dict + + +def clear_init_feed_collection(): + """Clear the init feed collection.""" + init_feed_collection = tf.get_collection_ref(TWML_INIT_FEED_KEY) + while init_feed_collection: + init_feed_collection.pop() diff --git a/twml/twml/contrib/layers/__init__.py b/twml/twml/contrib/layers/__init__.py new file mode 100644 index 000000000..aa6e7d7e4 --- /dev/null +++ b/twml/twml/contrib/layers/__init__.py @@ -0,0 +1,11 @@ +# pylint: disable=wildcard-import +""" This module contains all contrib Layers. """ + +from .hashed_percentile_discretizer import HashedPercentileDiscretizer # noqa: F401 +from .hashing_discretizer import HashingDiscretizer # noqa: F401 +from .mask_layer import MaskLayer # noqa: F401 +from .embedding_lookup import EmbeddingLookup # noqa: F401 +from .factorization_machine import FactorizationMachine # noqa: F401 +from .full_dense import full_dense, FullDense # noqa: F401 +from .stacked_rnn import StackedRNN, stacked_rnn # noqa: F401 +from .zscore_normalization import ZscoreNormalization, zscore_normalization # noqa: F401 diff --git a/twml/twml/contrib/layers/embedding_lookup.py b/twml/twml/contrib/layers/embedding_lookup.py new file mode 100644 index 000000000..c83dc7edd --- /dev/null +++ b/twml/twml/contrib/layers/embedding_lookup.py @@ -0,0 +1,419 @@ +import os +import re +import time + +from collections import OrderedDict + +from absl import logging +import numpy as np +import tensorflow.compat.v1 as tf +from tensorflow.python.ops.lookup_ops import index_table_from_tensor + +import twml + +# Padding is 0, UNK is 1: +PAD_WORD_ID = 0 +OOV_WORD_ID = 1 + + +def load_initializers_from_csv( + embedding_path, vocab_size=-1, embedding_size=None, separator=None, vocab=None +): + """ + Loads embeddings saved in the `glove format `_. + The glove format is a txt file separated by spaces. + Each line looks like: "word 0.00001 0.2334 ...". + + Arguments: + embedding_path: + path to the embeddings file on HDFS (hdfs://default/...) + or its local_path (/path/to/...). + The embedding_path may also specify a pattern. In which case, the embeddings + are read in the lexical order of the filenames that match the order. + vocab_size: + the maximum size of the vocabulary. The top ``vocab_size`` words in the file + are included in the vocabulary. 
If you specify a positive vocab_size,
+            the words are expected to be in descending order of frequency.
+            This allows the embeddings to be easily filtered to the top vocab_size words.
+            Reducing the vocab_size acts as a regularizer, preventing the model from overfitting on rarer words.
+            A negative vocab_size loads all embeddings.
+            Reducing the vocab_size may also help with memory issues,
+            allowing the embedding initializers to fit inside the graph.
+        embedding_size:
+            Defaults to None. If None, the embedding size is inferred from the file name.
+            For example, ``glove.300d.txt`` and ``glove300d200.txt`` will both be inferred
+            as ``embedding_size=300``. If this can't be done, the ``embedding_size`` is
+            inferred from the first line in the file. If ``embedding_size`` is provided,
+            only the last ``embedding_size`` values of each line are considered. This
+            allows the line parser to recover from partial word parsing errors.
+        separator:
+            Specifies the separator to use when splitting each line into values.
+            Default value is a whitespace (same as glove format).
+        vocab:
+            OrderedDict mapping words to np.array embedding vectors. Initializes the vocabulary.
+            Duplicate words found in the file are ignored.
+            Defaults to a vocabulary of two words::
+
+                vocab = OrderedDict()
+                vocab[''] = np.random.randn(embedding_size)
+                vocab[''] = np.random.randn(embedding_size)
+
+    Returns:
+        tuple of (vocab_initializer, weight_initializer, shape)
+
+        vocab_initializer:
+            A tf.constant_initializer containing a vector of word strings of size vocab_size.
+        weight_initializer:
+            A twml.contrib.initializers.partition_constant_initializer containing
+            the weight matrix of embeddings of size vocab_size x embedding_size.
+        shape:
+            A tuple of (vocab_size, embedding_size).
+
+    """
+
+    start = time.time()
+
+    embedding_path = twml.util.sanitize_hdfs_path(embedding_path)
+
+    is_user_vocab = True
+    if vocab is None:
+        vocab = OrderedDict()
+        vocab[''] = True
+        vocab[''] = True
+        is_user_vocab = False
+    elif not isinstance(vocab, OrderedDict):
+        raise RuntimeError(
+            "Expecting vocab argument of type OrderedDict or None. "
+            "Got type %s instead." % type(vocab).__name__
+        )
+
+    if embedding_size is None:
+        embedding_file = os.path.basename(embedding_path)
+        match = re.search(r"[^\d]([\d]+)d", embedding_file)
+        if match is not None:
+            embedding_size = int(match.group(1))
+
+    if embedding_size is not None and not isinstance(embedding_size, int):
+        raise RuntimeError(
+            "Expecting embedding_size argument of type int or None. "
+            "Got type %s instead." % type(embedding_size).__name__
+        )
+
+    embedding_paths = sorted(tf.io.gfile.glob(embedding_path))
+
+    if len(embedding_paths) > 1:
+        raise ValueError(
+            "You are most likely using the wrong --embedding.path"
+        )
+
+    embedding_path = embedding_paths[0]
+    logging.info("Reading embeddings file from path %s.." % embedding_path)
+
+    with tf.io.gfile.GFile(embedding_path) as f:
+        lines = f.readlines()
+
+    logging.info("Done reading embeddings file from path %s." % embedding_path)
+
+    logging.info("Parsing vocabulary and embeddings...")
+
+    for line in lines:
+        # Word and weights separated by space
+        values = line.strip().split(separator)
+        # Word is first symbol on each line
+        word = values[0]
+
+        if word not in vocab:
+            if embedding_size is None or embedding_size <= 0:
+                # get all elements after the first one.
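+                # e.g. (illustrative) the line "cat 0.1 0.2 0.3" yields word "cat"
+                # and word_weights ["0.1", "0.2", "0.3"], fixing embedding_size = 3
+                # for all subsequent lines.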
+ word_weights = values[1:] + embedding_size = len(word_weights) + else: + # get the last embedding_size elements + word_weights = values[-min(embedding_size, len(values) - 1) :] + + try: + if len(word_weights) != embedding_size: + raise ValueError + + word_weights = np.asarray(word_weights, dtype=np.float32) + vocab[word] = word_weights + except ValueError: + logging.info("Wasn't able to load embeddings for word '%s'. Ignoring it" % word) + + vocab_len = len(vocab) + if vocab_size > 0 and vocab_len == vocab_size: + # Limit vocabulary to top terms + break + elif (vocab_len % 1000) == 0: + logging.info("Loaded %d words into vocab" % vocab_len) + + else: + logging.info("found duplicate word: %s" % word) + + if not is_user_vocab: + vocab[''] = np.random.randn(embedding_size) + vocab[''] = np.random.randn(embedding_size) + + words = list(vocab.keys()) + weights = list(vocab.values()) + + weights = np.asarray(weights, dtype=np.float32) + assert weights.shape[0] == len(vocab) + assert weights.shape[1] == embedding_size + + vocab_initializer = tf.constant_initializer(words, tf.string) + weight_initializer = twml.contrib.initializers.PartitionConstant(weights, tf.float32) + + logging.info("Loaded %d embeddings in %d seconds." % (len(vocab), time.time() - start)) + return vocab_initializer, weight_initializer, weights.shape + + +def add_parser_arguments(parser): + """ + Adds the embedding.path and embedding.vocab_size command-line arguments to the parser. + These can be used to call an initializer loader function like + the ``load_initializers_from_csv`` function. + + Arguments: + parser: argparse.ArgumentParser instance obtained from Trainer.get_trainer_parser + + Returns: + argparse.ArgumentParser instance with discretizer-specific arguments added + """ + + parser.add_argument( + "--embedding.path", + "--embedding_path", + dest="embedding_path", + type=str, + default=None, + help="When specified, loads glove embeddings from .txt glove file", + ) + parser.add_argument( + "--embedding.vocab_size", + "--embedding_vocab_size", + dest="embedding_vocab_size", + type=int, + default=-1, + help="Size of vocabulary. Uses this many of the most frequent terms. Defaults to -1 (use full vocab).", + ) + + return parser + + +class EmbeddingLookup(twml.layers.Layer): + """Layer for looking up embeddings. + Transforms a sequence of strings to a sequence of embeddings. + + Arguments: + vocab_size: + The number of word strings and embeddings in the vocabulary. + output_size: + Long or Integer, dimensionality of the output space. The embedding vector size. + vocab_initializer: + Initializer function for the vocabulary. Required. The initializer should + return a list of strings of size vocab_size. + weight_initializer: + Initializer function for the weight matrix of size vocab_size x output_size. + This argument defaults to zeros_initializer(). + This is valid when the EmbeddingLookup is the first layer of + parameters but should be changed otherwise. + trainable: + Boolean, if `True` adds variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + Defaults to True: trains the embeddings. + num_oov_buckets: + The number of buckets to use for OOV strings. These bucket ids occur after the vocab bucket + ids. Hashing is used to assign OOV strings to these buckets. If `num_oov_buckets` is not + specified, index `OOV_WORD_ID` is used for OOV strings. + name: + String, the name of the layer. 
Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + num_partitions: + Number of partitions to use for the weight variable. Defaults to 1. + partition_axis: + If num_partitions is specified, the partition axis for the weight variable + Defaults to 0 (partition by row). + Must be 0 (row) or 1 (column, does not support yet) + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + dtype: + Defaults to tf.float32. Specifies the dtype of the weights. + use_placeholder: + Defaults to True. + If set to `True`, the initializer is passed via a placeholder. The initializer in this case needs to be of type `keras.initializers.Constant`. + If set to `False`, the initializer becomes part of the graph. This can sometimes be beyond what protobuf clients support. + checkpoint_dir: + Default to None. + If set to the path of a checkpoint, load embedding from the checkpoint. + convert_to_lowercase: + Default to True. + Converting all string inputs to lowercase. + + Notes: If `use_placeholder` is set to `True`, the feed dictionary can be accessed by calling `twml.contrib.initializers.get_init_feed_dict()`. + """ + + def __init__( + self, + vocab_size, + output_size, + vocab_initializer, + weight_initializer=None, + trainable=True, + num_oov_buckets=None, + oov_word_id=None, + name=None, + num_partitions=1, + partition_axis=0, + weight_regularizer=None, + dtype=None, + use_placeholder=True, + checkpoint_dir=None, + convert_to_lowercase=True, + **kwargs, + ): + if dtype is None: + # prevents a bug where the parent class defaults to the type of the first input tensor. + dtype = tf.float32 + super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) + # Weights initialization is set to 0s. This is safe for full sparse layers because + # you are supposed to learn your embedding from the label. + + is_constant_init = isinstance(weight_initializer, tf.keras.initializers.Constant) + if use_placeholder and (not is_constant_init) and (weight_initializer is not None): + raise ValueError("Weight initializer should be a `Constant` or `None`.") + + if weight_initializer is None: + self.weight_initializer = tf.zeros_initializer() + else: + self.weight_initializer = weight_initializer + self.use_placeholder = use_placeholder + self.checkpoint_dir = checkpoint_dir + self.convert_to_lowercase = convert_to_lowercase + + self.vocab_initializer = vocab_initializer + self.vocab_size = vocab_size + self.output_size = output_size + self.num_partitions = num_partitions + self.partition_axis = partition_axis + self.weight_regularizer = weight_regularizer + self.trainable = trainable + self.oov_word_id = oov_word_id + self.num_oov_buckets = num_oov_buckets + + if self.oov_word_id is not None and self.num_oov_buckets is not None: + raise ValueError("At most one of oov_word_id or num_oov_buckets should be specified") + elif self.oov_word_id is None and self.num_oov_buckets is None: + self.oov_word_id = OOV_WORD_ID # use the default OOV word id + + if partition_axis != 0: + raise NotImplementedError("embedding_lookup only supports partition_axis = 0") + + def build(self, input_shapes): + """ + creates the ``vocab`` and ``weight`` Variables + of shape ``[vocab_size]`` and ``[vocab_size, output_size]`` respectively. 
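+
+        Note: when ``num_oov_buckets`` is set, the first dimension of ``weight``
+        is ``vocab_size + num_oov_buckets`` rather than ``vocab_size`` (this
+        mirrors the shape computation in the method body below).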
+ """ + partitioner = None + + additional_buckets_for_oov = self.num_oov_buckets if self.num_oov_buckets is not None else 0 + shape = [self.vocab_size + additional_buckets_for_oov, self.output_size] + + if self.use_placeholder: + embedding_weight_initializer = twml.contrib.initializers.PlaceholderInitializer( + shape, self.dtype + ) + tf.add_to_collection( + twml.contrib.initializers.TWML_INIT_FEED_KEY, + {embedding_weight_initializer.value: self.weight_initializer.value}, + ) + else: + embedding_weight_initializer = self.weight_initializer + + if self.num_partitions: + partition_axis = int(self.partition_axis) + partitioner = tf.fixed_size_partitioner(self.num_partitions, axis=partition_axis) + else: + # Regular variables do not like it when you pass both constant tensors and shape + if not callable(self.weight_initializer): + shape = None + + self.vocab = self.add_variable( + 'vocab', + initializer=self.vocab_initializer, + shape=[self.vocab_size], + dtype=tf.string, + trainable=False, + ) + + self.weight = self.add_variable( + 'weight', + initializer=None if self.checkpoint_dir is not None else embedding_weight_initializer, + regularizer=self.weight_regularizer, + shape=shape, + dtype=self.dtype, + trainable=self.trainable, + partitioner=partitioner, + ) + if self.checkpoint_dir is not None: + twml.trainers.trainer.init_from_checkpoint(self.checkpoint_dir, {'weight': self.weight.name}) + + self.built = True + + def call( + self, inputs, debug=False, oov_summaries=False, **kwargs + ): # pylint: disable=unused-argument + """Converts word strings to word ids using the vocabulary lookup table. + Then converts the word ids to their commensurate embedding vector. + + Arguments: + inputs: + A tensor of word strings. Typically, of size batch_size x seq_len. + debug: + When True, prints the input strings and their commensurate input_ids. + Defaults to False. + oov_summaries: + When True, log the out-of-vocabulary (OOV) rate to TensorBoard + Defaults to False. + + Returns: + The mapping of input word strings to output embedding vectors. + Given an input of shape ``batch_size x seq_len``, the output has shape + ``batch_size x seq_len x embedding_size``. 
+ """ + if self.convert_to_lowercase: + inputs = tf.strings.lower(inputs) + if self.num_oov_buckets is None: + lookup_table = index_table_from_tensor(self.vocab, default_value=self.oov_word_id) + else: + lookup_table = index_table_from_tensor(self.vocab, num_oov_buckets=self.num_oov_buckets) + input_ids = lookup_table.lookup(inputs) + + if oov_summaries: + oov_count = tf.reduce_sum( + tf.cast(tf.math.equal(input_ids, self.oov_word_id), tf.dtypes.float32) + ) + valid_count = tf.reduce_sum( + tf.cast(tf.math.not_equal(input_ids, PAD_WORD_ID), tf.dtypes.float32) + ) + oov_rate = oov_count / valid_count + tf.summary.scalar('OOV_rate', oov_rate) + + if debug: + + def print_debug(): + return tf.print("input_strings:", inputs, "\ninput_ids: ", input_ids, summarize=140) + + with tf.control_dependencies([twml.util.do_every_n_steps(print_debug, 1000)]): + input_ids = tf.identity(input_ids) + + output_embeddings = tf.nn.embedding_lookup( + params=self.weight, ids=input_ids, partition_strategy='div' + ) + + output_shape = inputs.shape.concatenate(tf.TensorShape([self.output_size])) + output_embeddings.set_shape(output_shape) + + return output_embeddings diff --git a/twml/twml/contrib/layers/factorization_machine.py b/twml/twml/contrib/layers/factorization_machine.py new file mode 100644 index 000000000..3b8adae42 --- /dev/null +++ b/twml/twml/contrib/layers/factorization_machine.py @@ -0,0 +1,179 @@ +# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument +""" +Implementing factorization Layer +""" + +from twitter.deepbird.sparse.sparse_ops import _pad_empty_outputs + +import tensorflow.compat.v1 as tf +import twml +from twml.layers.layer import Layer + + +class FactorizationMachine(Layer): + """factorization machine layer class. + This layer implements the factorization machine operation. + The paper is "Factorization Machines" by Steffen Rendle. + TDD: go/tf-fm-tdd + + Arguments: + num_latent_variables: + num of latent variables + The number of parameter in this layer is num_latent_variables x n where n is number of + input features. + weight_initializer: + Initializer function for the weight matrix. + This argument defaults to zeros_initializer(). + This is valid when the FullSparse is the first layer of + parameters but should be changed otherwise. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + activation: + Activation function (callable). Set it to None to maintain a linear activation. + trainable: + Boolean, if `True` also add variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + name: + String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + use_sparse_grads: + Boolean, if `True` do sparse mat mul with `embedding_lookup_sparse`, which will + make gradients to weight matrix also sparse in backward pass. This can lead to non-trivial + speed up at training time when input_size is large and optimizer handles sparse gradients + correctly (eg. with SGD or LazyAdamOptimizer). If weight matrix is small, it's recommended + to set this flag to `False`; for most use cases of FullSparse, however, weight matrix will + be large, so it's better to set it to `True` + use_binary_values: + Assume all non zero values are 1. Defaults to False. + This can improve training if used in conjunction with MDL. 
This parameter can also be a list of binary values if `inputs` passed to `call` is a list.
+  """
+
+  def __init__(self,
+               num_latent_variables=10,
+               weight_initializer=None,
+               activation=None,
+               trainable=True,
+               name=None,
+               use_sparse_grads=True,
+               use_binary_values=False,
+               weight_regularizer=None,
+               substract_self_cross=True,
+               **kwargs):
+    super(FactorizationMachine, self).__init__(trainable=trainable, name=name, **kwargs)
+
+    if weight_initializer is None:
+      weight_initializer = tf.zeros_initializer()
+    self.weight_initializer = weight_initializer
+    self.num_latent_variables = num_latent_variables
+    self.activation = activation
+    self.use_sparse_grads = use_sparse_grads
+    self.use_binary_values = use_binary_values
+    self.weight_regularizer = weight_regularizer
+    self.substract_self_cross = substract_self_cross
+
+  def build(self, input_shape):
+    """
+    creates the ``weight`` Variable of shape ``[input_size, num_latent_variables]``.
+
+    """
+
+    shape = [input_shape[1], self.num_latent_variables]
+
+    # There is a 2GB limitation for each tensor because of protobuf.
+    # 2**30 is 1GB. 2 * (2**30) is 2GB.
+    dtype = tf.as_dtype(self.dtype)
+    requested_size = input_shape[1] * self.num_latent_variables * dtype.size
+    if (requested_size >= 2**31):
+      raise ValueError("Weight tensor can not be larger than 2GB. "
+                       "Requested Dimensions(%d, %d) of type %s (%d bytes total)" %
+                       (input_shape[1], self.num_latent_variables, dtype.name, requested_size))
+
+    if not callable(self.weight_initializer):
+      shape = None
+
+    # dense tensor
+    self.weight = self.add_variable(
+      'weight',
+      initializer=self.weight_initializer,
+      regularizer=self.weight_regularizer,
+      shape=shape,
+      dtype=self.dtype,
+      trainable=True,
+    )
+
+    self.built = True
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+
+    Raises NotImplementedError.
+
+    """
+    raise NotImplementedError
+
+  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
+    """The logic of the layer lives here.
+
+    Arguments:
+      inputs:
+        A SparseTensor
+    Returns:
+      - If `inputs` is `SparseTensor`, then returns a number with cross info
+    """
+    # The following are given:
+    # - inputs is a sparse tensor, we call it sp_x.
+    # - The dense_v tensor is a dense matrix, whose row i
+    #   corresponds to the vector V_i.
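+    # For reference, the FM second-order interaction term is
+    #   sum_{i<j} <V_i, V_j> x_i x_j
+    #     = 0.5 * sum_k [ (sum_i v_ik x_i)^2 - sum_i (v_ik x_i)^2 ]
+    # The code below computes the two bracketed terms with per-example segment
+    # sums; the constant 0.5 is omitted, which only rescales the output.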
+ # weights has shape [num_features, k] + sp_x = inputs + if isinstance(inputs, twml.SparseTensor): + sp_x = inputs.to_tf() + elif not isinstance(sp_x, tf.SparseTensor): + raise TypeError("The sp_x must be of type tf.SparseTensor or twml.SparseTensor") + + indices = sp_x.indices[:, 1] + batch_ids = sp_x.indices[:, 0] + values = tf.reshape(sp_x.values, [-1, 1], name=self.name) + if self.use_sparse_grads: + v = tf.nn.embedding_lookup(self.weight, indices) + # if (self.use_binary_values): + # values = tf.ones(tf.shape(values), dtype=values.dtype) + v_times_x = v * values + # First term: Sum_k [Sum_i (v_ik * x_i)]^2 + all_crosses = tf.segment_sum(v_times_x, batch_ids, name=self.name) + all_crosses_squared = tf.reduce_sum((all_crosses * all_crosses), 1) + + if self.substract_self_cross: + # Second term: Sum_k Sum_i [ (v_ik * x_i)^2 ] + v_times_x_2 = v_times_x**2 + self_crosses = tf.reduce_sum(tf.segment_sum(v_times_x_2, batch_ids, name=self.name), 1) + outputs = all_crosses_squared - self_crosses + else: + outputs = all_crosses_squared + else: + # need to check if prediction is faster with code below + crossTerm = tf.reduce_sum((tf.sparse_tensor_dense_matmul(sp_x, self.weight)**2), 1) + + if self.substract_self_cross: + # compute self-cross term + self_crossTerm = tf.reduce_sum(tf.segment_sum((tf.gather(self.weight, indices) * values)**2, batch_ids), 1) + outputs = crossTerm - self_crossTerm + else: + outputs = crossTerm + + if self.activation is not None: + outputs = self.activation(outputs) + + outputs = tf.reshape(outputs, [-1, 1], name=self.name) + outputs = _pad_empty_outputs(outputs, tf.cast(sp_x.dense_shape[0], tf.int32)) + # set more explicit and static shape to avoid shape inference error + # valueError: The last dimension of the inputs to `Dense` should be defined. Found `None` + outputs.set_shape([None, 1]) + return outputs diff --git a/twml/twml/contrib/layers/full_dense.py b/twml/twml/contrib/layers/full_dense.py new file mode 100644 index 000000000..ad78a91a4 --- /dev/null +++ b/twml/twml/contrib/layers/full_dense.py @@ -0,0 +1,380 @@ +# pylint: disable=no-member,arguments-differ, attribute-defined-outside-init +""" +Implementing Full Dense Layer +""" +from twml.layers import Layer + +import tensorflow.compat.v1 as tf +from tensorflow.python.layers import core + + +class FullDense(Layer): + """ + Full-connected, Dense input layer class. + This layer implements the operation: + + .. code-block:: python + + outputs = activation(inputs.weight + bias) + + Where ``activation`` is the activation function passed as the ``activation`` + argument (if not ``None``), ``weight`` is a weights matrix created by the layer, + and ``bias`` is a bias vector created by the layer. + + However, this layer breaks up ``weight`` into ``num_partitions`` parts, + for the purpose of even disribution of weights across parameter servers + for distributed training. + + Note - This layer is created to allow distributed training optimizations, + but can also be used for single node training (e.g. hogwild) without + code modification + + Arguments: + output_size: + Integer or Long, dimensionality of the output space. + weight_initializer: + Initializer function for the weight matrix. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + weight_constraint: + An optional projection function to be applied to the + weight after being updated by an `Optimizer` (e.g. 
used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: + An optional projection function to be applied to the + bias after being updated by an `Optimizer`. + num_partitions: + Number of pieces to partition the weights into. This layer does + column partitioning of the weights, which is equivalent to + processing the input tensor with multiple fully connected layers + of smaller output size, and then concatenating these outputs + activation: + Activation function (callable). Set it to None to maintain a linear activation. + use_bias: + Boolean whether to include a bias parameter in the layer + bias_initializer: + Initializer function for the bias. + bias_regularizer: + Regularizer function for the bias. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + activity_regularizer: + Regularizer function for the output. + trainable: + Boolean, if `True` also add variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + name: + String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + + Properties: + output_size: + Python integer, dimensionality of the output space. + activation: + Activation function (callable). + weight_initializer: + Initializer instance (or name) for the weight matrix. + bias_initializer: + Initializer instance (or name) for the bias. + weights: + list of underlying weight and bias matrix components. no guarantee on order of elements + weight_regularizer: + Regularizer instance for the weight matrix (callable) + bias_regularizer: + Regularizer instance for the bias (callable). + activity_regularizer: + Regularizer instance for the output (callable) + weight_constraint: + Constraint function for the weight matrix. + bias_constraint: + Constraint function for the bias. 
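+
+    Example (a minimal sketch; the sizes below are illustrative assumptions):
+
+    .. code-block:: python
+
+      layer = FullDense(output_size=128, num_partitions=4, activation=tf.nn.relu)
+      outputs = layer(inputs)  # inputs: a [batch_size, input_size] dense tensor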
+ """ + + def __init__(self, output_size, + weight_initializer=None, + weight_regularizer=None, + weight_constraint=None, + bias_constraint=None, + num_partitions=3, + activation=None, + use_bias=True, + bias_initializer=tf.zeros_initializer(), + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(FullDense, self).__init__(trainable=trainable, name=name, **kwargs) + self._output_sizes = self._get_output_partition_sizes(output_size, num_partitions) + self._units = output_size + self._activation = activation + self._weight_initializer = weight_initializer + self._bias_initializer = bias_initializer + self._weight_regularizer = weight_regularizer + self._bias_regularizer = bias_regularizer + self._weight_constraint = weight_constraint + self._bias_constraint = bias_constraint + self._use_bias = use_bias + # NOTE - many initializers depend on fan_in and fan_out + # - as such, initialization here may be different than + # - for a non-partitioned FullDense + self._parts = [core.Dense(units=out_size, + activation=activation, + use_bias=use_bias, + kernel_initializer=weight_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=weight_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=weight_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) for out_size in self._output_sizes] + + @staticmethod + def _get_output_partition_sizes(out_size, num_parts): + """ Returns the appropriate output sizes of the partitions """ + boundaries = [out_size * n // num_parts for n in range(num_parts + 1)] + return [k - j for j, k in zip(boundaries[:], boundaries[1:])] + + def build(self, input_shapes): + """ Create the appropriately sized weights and biases in each layer partition """ + if isinstance(input_shapes, (list, tuple)): + input_shape = input_shapes[0] + is_compatible = True + for other_shape in input_shapes[1:]: + is_compatible &= input_shape.is_compatible_with(other_shape) + if not is_compatible: + raise ValueError("Input shapes %s are not compatible." 
% input_shapes) + else: + input_shape = input_shapes + + for part in self._parts: + part.build(input_shape) + + self.built = True + + @property + def units(self): + """ Returns the number of output units of the layer """ + return self._units + + @property + def output_size(self): + """ Returns the number of output units of the layer """ + return self._units + + @property + def activation(self): + """ Returns the activation function """ + return self._activation + + @property + def weight_initializer(self): + """ Returns the weight_initializer """ + return self._weight_initializer + + @property + def weight_regularizer(self): + """ Returns the weight_regularizer """ + return self._weight_regularizer + + @property + def weight_constraint(self): + """ Returns the weight_constraint """ + return self._weight_constraint + + @property + def bias_initializer(self): + """ Returns the bias_initializer """ + return self._bias_initializer + + @property + def bias_regularizer(self): + """ Returns the bias_regularizer """ + return self._bias_regularizer + + @property + def bias_constraint(self): + """ Returns the bias_constraint """ + return self._bias_constraint + + @property + def use_bias(self): + """ Returns whether a bias is used in the layer """ + return self._use_bias + + @property + def trainable_variables(self): + """ Returns the trainable variables of the layer """ + trainable_vars = [] + for pt in self._parts: + trainable_vars += pt.trainable_variables + return trainable_vars + + @property + def trainable_weights(self): + """ Returns the trainable variables of the layer """ + return self.trainable_variables + + @property + def non_trainable_variables(self): + """ Returns the non-trainable variables of the layer """ + non_trainable_vars = [] + for pt in self._parts: + non_trainable_vars += pt.non_trainable_variables + return non_trainable_vars + + @property + def non_trainable_weights(self): + """ Returns the non-trainable variables of the layer """ + return self.non_trainable_variables + + @property + def variables(self): + """ Returns a list of all weights and biases in this layer """ + layer_vars = [] + for pt in self._parts: + layer_vars += pt.weights + return layer_vars + + @property + def weights(self): + """ Returns a list of all weights and biases in this layer """ + return self.variables + + @property + def dtype(self): + """ Returns the dtype of the layers weights """ + return self._parts[0].dtype + + def call(self, inputs, **kwargs): # pylint: disable=unused-argument + """The logic of the layer lives here. + + Arguments: + inputs: + A dense Tensor or a list of such. + If `inputs` is a list, all tensors must have same `dense_shape`. + + Returns: + - If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`. + - If `inputs` is a `list[SparseTensor`, then returns + `bias + accumulate_n([sp_a * dense_b for sp_a in inputs])`. 
+ """ + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + outputs = [] + for inp in inputs: + part_outputs = [part(inp) for part in self._parts] + outputs.append(tf.concat(part_outputs, axis=-1)) + + return tf.accumulate_n(outputs) + + +def full_dense(inputs, output_size, + weight_initializer=None, + weight_regularizer=None, + weight_constraint=None, + bias_constraint=None, + num_partitions=3, + activation=None, + use_bias=True, + bias_initializer=tf.zeros_initializer(), + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + reuse=None, + **kwargs): + """Functional interface for the fully-connected dense-input layer. + This layer implements the operation: + `outputs = activation(inputs.weight + bias)` + Where `activation` is the activation function passed as the `activation` + argument (if not `None`), `weight` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only if `use_bias` is `True`). + + However, this layer breaks up ``weight`` into ``num_partitions`` parts, + for the purpose of even disribution of weights across parameter servers + for distributed training. + + Note - This layer is created to allow distributed training optimizations, + but can also be used for single node training (e.g. hogwild) without + code modification + + Arguments: + inputs: Tensor input. + output_size: Integer or Long, dimensionality of the output space. + weight_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the default + initializer used by `tf.get_variable`. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + weight_constraint: + An optional projection function to be applied to the + weight after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: + An optional projection function to be applied to the + bias after being updated by an `Optimizer`. + num_partitions: + Number of pieces to partition the weights into. This layer does + column partitioning of the weights, which is equivalent to + processing the input tensor with multiple fully connected layers + of smaller output size, and then concatenating these outputs + activation: Activation function (callable). Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + bias_initializer: + Initializer function for the bias. + bias_regularizer: + Regularizer function for the bias. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + activity_regularizer: + Regularizer function for the output. + trainable: + Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: + String, the name of the layer. + reuse: + Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor with shape `inputs.shape[:-1] + [output_size]`. 
+ """ + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + dtype = inputs[0].dtype.base_dtype + + layer = FullDense(output_size=output_size, + weight_initializer=weight_initializer, + weight_regularizer=weight_regularizer, + weight_constraint=weight_constraint, + bias_constraint=bias_constraint, + num_partitions=num_partitions, + activation=activation, + use_bias=use_bias, + bias_initializer=bias_initializer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + dtype=dtype, + _scope=name, + _reuse=reuse, + **kwargs) + + return layer(inputs) diff --git a/twml/twml/contrib/layers/hashed_percentile_discretizer.py b/twml/twml/contrib/layers/hashed_percentile_discretizer.py new file mode 100644 index 000000000..b32c3be8d --- /dev/null +++ b/twml/twml/contrib/layers/hashed_percentile_discretizer.py @@ -0,0 +1,217 @@ +# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes +""" +Implementing HashedPercentileDiscretizer Layer +""" + + +from twitter.deepbird.util.hashing import ( + integer_multiplicative_hashing_uniform, + integer_multiplicative_hashing, +) # noqa: F401 + +from libtwml import percentile_discretizer_bin_indices +import numpy as np +import tensorflow.compat.v1 as tf +import twml +from twml.layers.layer import Layer +from twml.layers.partition import Partition +from twml.layers.stitch import Stitch + + +class HashedPercentileDiscretizer(Layer): + """ + HashedPercentileDiscretizer layer is constructed by PercentileDiscretizerCalibrator + after accumulating data + and performing minimum description length (PercentileDiscretizer) calibration. + + HashedPercentileDiscretizer takes sparse continuous features and converts then to sparse + binary features. Each binary output feature is associated to an HashedPercentileDiscretizer + bin. + Each HashedPercentileDiscretizer input feature is converted to n_bin bins. + Each HashedPercentileDiscretizer calibration tries to find bin delimiters such + that the number of features values + per bin is roughly equal (for each given HashedPercentileDiscretizer feature). + Note that if an input feature is rarely used, so will its associated output bin/features. + The difference between this layer and PercentileDiscretizer is that the + DeterministicPercentileDiscretize always assigns the same output id in the SparseTensor to the + same input feature id + bin. This is useful if you want to user transfer learning on pre-trained + sparse to dense embedding layers, but re-calibrate your discretizer on newer data. + """ + + def __init__(self, n_feature, n_bin, out_bits, + bin_values=None, hash_keys=None, hash_values=None, + bin_ids=None, feature_offsets=None, + hash_fn=integer_multiplicative_hashing_uniform, **kwargs): + """ + Creates a non-initialized `HashedPercentileDiscretizer` object. + Before using the table you will have to initialize it. After initialization + the table will be immutable. + + Parent class args: + see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer) + for documentation of parent class arguments. + + Required args: + n_feature: + number of unique features accumulated during HashedPercentileDiscretizer calibration. + This is the number of features in the hash map. + Used to initialize bin_values, hash_keys, hash_values, + bin_ids, bin_values and feature_offsets. + n_bin: + number of HashedPercentileDiscretizer bins used for + HashedPercentileDiscretizer calibration. 
Used to initialize bin_values, hash_keys,
+        hash_values, bin_ids, bin_values and feature_offsets.
+      out_bits:
+        Determines the maximum value for output feature IDs.
+        The dense_shape of the SparseTensor returned by lookup(x)
+        will be [x.shape[0], 1 << output_bits].
+
+    Optional args:
+      hash_keys:
+        contains the feature IDs that HashedPercentileDiscretizer discretizes and knows
+        about. The hash map (hash_keys->hash_values) is used for two reasons:
+        1. divide inputs into two feature spaces:
+           HashedPercentileDiscretizer vs non-HashedPercentileDiscretizer
+        2. translate the HashedPercentileDiscretizer features into hash_feature IDs that
+           HashedPercentileDiscretizer understands.
+        The hash_map is expected to contain n_feature items.
+      hash_values:
+        translates the feature IDs into hash_feature IDs for HashedPercentileDiscretizer.
+      bin_ids:
+        a 1D Tensor of size n_feature * n_bin + 1 which contains
+        the unique IDs to which the HashedPercentileDiscretizer features will be translated.
+        For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
+        the most efficient output space.
+      bin_values:
+        a 1D Tensor aligned with bin_ids.
+        For a given hash_feature ID j, its value bins are indexed between
+        `j*n_bin` and `j*n_bin + n_bin-1`.
+        As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
+        and an input value between
+        `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
+      feature_offsets:
+        a 1D Tensor specifying the starting location of bins for a given feature id.
+        For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
+      hash_fn:
+        a function that takes in `feature_ids`, `bucket_indices` and `output_size` and
+        hashes the bucketed features into the `output_size` buckets. The default uses Knuth's
+        multiplicative hashing.
+    """
+    super(HashedPercentileDiscretizer, self).__init__(**kwargs)
+
+    max_discretizer_feature = n_feature * (n_bin + 1)
+    self._n_feature = n_feature
+    self._n_bin = n_bin
+
+    if not self.built:
+      self.build(input_shape=None)
+
+    # build variables
+    self.output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)
+    self._out_bits = out_bits
+
+    if hash_keys is None:
+      hash_keys = np.empty(n_feature, dtype=np.int64)
+
+    if hash_values is None:
+      hash_values = np.empty(n_feature, dtype=np.int64)
+
+    initializer = tf.lookup.KeyValueTensorInitializer(hash_keys, hash_values)
+    self.hash_map = tf.lookup.StaticHashTable(initializer, -1)
+
+    # default the optional arrays before storing them, so the attributes
+    # used by call() are never left as None
+    if bin_ids is None:
+      bin_ids = np.empty(max_discretizer_feature, dtype=np.int64)
+    self.bin_ids = bin_ids
+
+    if bin_values is None:
+      bin_values = np.empty(max_discretizer_feature, dtype=np.float32)
+    self.bin_values = bin_values
+
+    if feature_offsets is None:
+      feature_offsets = np.empty(n_feature, dtype=np.int64)
+    self.feature_offsets = feature_offsets
+
+    self.hash_fn = hash_fn
+
+  def build(self, input_shape):  # pylint: disable=unused-argument
+    """
+    Creates the variables of the layer:
+    hash_keys, hash_values, bin_ids, bin_values, feature_offsets and self.output_size.
+    """
+    # build layers
+    self.partition = Partition()
+    self.stitch = Stitch()
+    # make sure this is last
+    self.built = True
+
+  def call(self, inputs, **kwargs):
+    """Looks up `keys` in a table, outputs the corresponding values.
+
+    Implements HashedPercentileDiscretizer inference where inputs are intersected with a
+    hash_map.
+    Part of the inputs are discretized using twml.discretizer
+    to produce a discretizer_output SparseTensor.
+ This SparseTensor is then joined with the original inputs SparseTensor, + but only for the inputs keys that did not get discretized. + + Args: + inputs: A 2D SparseTensor that is input to HashedPercentileDiscretizer for + discretization. It has a dense_shape of [batch_size, input_size] + name: A name for the operation (optional). + Returns: + A `SparseTensor` of the same type as `inputs`. + Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits]. + """ + if isinstance(inputs, tf.SparseTensor): + inputs = twml.SparseTensor.from_tf(inputs) + + assert(isinstance(inputs, twml.SparseTensor)) + + # sparse column indices + ids = inputs.ids + # sparse row indices + keys = inputs.indices + # sparse values + vals = inputs.values + + hashed_keys = self.hash_map.lookup(keys) + hashed_keys = tf.cast(hashed_keys, tf.int64) + + found = tf.not_equal(hashed_keys, tf.constant(-1, tf.int64)) + partition_ids = tf.cast(found, tf.int32) + + found = tf.reshape(found, [-1]) + continuous_feature_ids = tf.boolean_mask(keys, found) + + vals, key, indices = self.partition(partition_ids, vals, tf.where(found, hashed_keys, keys)) + non_discretizer_keys, discretizer_in_keys = key + non_discretizer_vals, discretizer_in_vals = vals + + non_discretizer_keys = twml.util.limit_bits(non_discretizer_keys, self._out_bits) + self.non_discretizer_keys = non_discretizer_keys + + # run HashedPercentileDiscretizer on the keys/values it knows about + output = percentile_discretizer_bin_indices(discretizer_in_keys, + discretizer_in_vals, + self.bin_ids, + self.bin_values, + self.feature_offsets) + discretizer_bucket_idxs, discretizer_vals = output + new_discretizer_keys = self.hash_fn(continuous_feature_ids, discretizer_bucket_idxs, + self.output_size) + # Stitch the keys and values from discretizer and non discretizer indices back, with help + # of the Stitch Layer + self.discretizer_out_keys = new_discretizer_keys + + concat_data = self.stitch([non_discretizer_vals, discretizer_vals], + [non_discretizer_keys, new_discretizer_keys], + indices) + + concat_vals, concat_keys = concat_data + + # Generate output shape using _compute_output_shape + + batch_size = tf.to_int64(inputs.dense_shape[0]) + output_shape = [batch_size, self.output_size] + return twml.SparseTensor(ids, concat_keys, concat_vals, output_shape).to_tf() diff --git a/twml/twml/contrib/layers/hashing_discretizer.py b/twml/twml/contrib/layers/hashing_discretizer.py new file mode 100644 index 000000000..2a8244f4b --- /dev/null +++ b/twml/twml/contrib/layers/hashing_discretizer.py @@ -0,0 +1,156 @@ +# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes +""" +Implementing HashingDiscretizer Layer +""" + + +import libtwml +import tensorflow.compat.v1 as tf +import twml +from twml.constants import HashingDiscretizerOptions +from twml.layers.layer import Layer + + +class HashingDiscretizer(Layer): + """A layer that discretizes continuous features, with hashed feature assignments + + HashingDiscretizer converts sparse continuous features into sparse + binary features. Each binary output feature indicates the presence of a + value in a HashingDiscretizer bin. + + Each calibrated HashingDiscretizer input feature is converted to n_bin+1 bins. + + - n_bin bin boundaries for each feature (i.e. 
len(bin_vals[id])==n_bin) defines n_bin+1 bins
+  - bin assignment = sum(bin_vals < value)
+    if len(self._feature_ids) > 0:
+      # pass all inputs to the c++ op
+      # the op determines whether to discretize (when a feature is calibrated),
+      # or whether to simply limit bits and pass through (when not calibrated)
+      # NOTE - Hashing is done in C++
+      discretizer_keys, discretizer_vals = libtwml.ops.hashing_discretizer(
+        input_ids=keys,  # Input
+        input_vals=vals,  # Input
+        bin_vals=self._bin_vals,  # Input
+        feature_ids=tf.make_tensor_proto(self._feature_ids),  # Attr
+        n_bin=self._n_bin,  # Attr
+        output_bits=self._out_bits,  # Attr
+        cost_per_unit=self.cost_per_unit,  # Attr
+        options=self._options,  # Attr
+      )
+    else:
+      discretizer_keys = twml.util.limit_bits(keys, self._out_bits)
+      discretizer_vals = vals
+
+    batch_size = tf.to_int64(inputs.dense_shape[0])
+    output_size = tf.convert_to_tensor(1 << self._out_bits, tf.int64)
+    output_shape = [batch_size, output_size]
+
+    return twml.SparseTensor(ids, discretizer_keys, discretizer_vals, output_shape).to_tf()
diff --git a/twml/twml/contrib/layers/mask_layer.py b/twml/twml/contrib/layers/mask_layer.py
new file mode 100644
index 000000000..f5e788c7b
--- /dev/null
+++ b/twml/twml/contrib/layers/mask_layer.py
@@ -0,0 +1,29 @@
+from twml.contrib.pruning import apply_mask
+from twml.layers import Layer
+
+
+class MaskLayer(Layer):
+  """
+  This layer corresponds to `twml.contrib.pruning.apply_mask`.
+
+  It applies a binary mask to mask out channels of a given tensor. The masks can be
+  optimized using `twml.contrib.trainers.PruningDataRecordTrainer`.
+  """
+
+  def call(self, inputs, **kwargs):
+    """
+    Applies a binary mask to the channels of the input.
+
+    Arguments:
+      inputs:
+        input tensor
+      **kwargs:
+        additional keyword arguments
+
+    Returns:
+      Masked tensor
+    """
+    return apply_mask(inputs)
+
+  def compute_output_shape(self, input_shape):
+    return input_shape
diff --git a/twml/twml/contrib/layers/stacked_rnn.py b/twml/twml/contrib/layers/stacked_rnn.py
new file mode 100644
index 000000000..e05f5d853
--- /dev/null
+++ b/twml/twml/contrib/layers/stacked_rnn.py
@@ -0,0 +1,189 @@
+
+from twitter.deepbird.compat.v1.rnn import stack_bidirectional_dynamic_rnn
+
+import tensorflow.compat.v1 as tf
+import tensorflow
+import twml
+
+
+def _get_rnn_cell_creator(cell_type):
+  if cell_type == "LSTM":
+    Cell = tf.nn.rnn_cell.LSTMCell
+  elif cell_type == "GRU":
+    Cell = tf.nn.rnn_cell.GRUCell
+  else:
+    raise ValueError("cell_type: %s is not supported. "
+                     "It should be one of 'LSTM' or 'GRU'."
+                     % cell_type)
+  return Cell
+
+
+def _apply_dropout_wrapper(rnn_cells, dropout):
+  """ Apply dropout wrapper around each cell if necessary """
+  if rnn_cells is None:
+    return None
+
+  cells = []
+  for i, dropout_rate in enumerate(dropout):
+    cell = rnn_cells[i]
+    if dropout_rate > 0:
+      cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=(1.0 - dropout_rate))
+    cells.append(cell)
+  return cells
+
+
+def _create_bidirectional_rnn_cell(num_units, dropout, cell_type):
+  scope_name = "lstm" if cell_type == "LSTM" else "gru"
+  with tf.variable_scope(scope_name):
+    Cell = _get_rnn_cell_creator(cell_type)
+    cells_forward = [Cell(output_size) for output_size in num_units]
+    cells_backward = [Cell(output_size) for output_size in num_units]
+    cells_forward = _apply_dropout_wrapper(cells_forward, dropout)
+    cells_backward = _apply_dropout_wrapper(cells_backward, dropout)
+
+  def stacked_rnn_cell(inputs, sequence_lengths):
+    with tf.variable_scope(scope_name):
+      outputs, final_states, _ = stack_bidirectional_dynamic_rnn(
+        cells_fw=cells_forward, cells_bw=cells_backward, inputs=inputs,
+        sequence_length=sequence_lengths, dtype=inputs.dtype)
+      return final_states[-1][-1]
+
+  return stacked_rnn_cell
+
+
+def _create_unidirectional_rnn_cell(num_units, dropout, cell_type):
+  scope_name = "lstm" if cell_type == "LSTM" else "gru"
+  with tf.variable_scope(scope_name):
+    Cell = _get_rnn_cell_creator(cell_type)
+    cells = [Cell(output_size) for output_size in num_units]
+    cells = _apply_dropout_wrapper(cells, dropout)
+    multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
+
+  def stacked_rnn_cell(inputs, sequence_lengths):
+    with tf.variable_scope(scope_name):
+      outputs, final_states = tf.nn.static_rnn(
+        multi_cell,
+        tf.unstack(inputs, axis=1),
+        dtype=inputs.dtype,
+        sequence_length=sequence_lengths)
+      return final_states[-1].h
+
+  return stacked_rnn_cell
+
+
+def _create_regular_rnn_cell(num_units, dropout, cell_type, is_bidirectional):
+  if is_bidirectional:
+    return _create_bidirectional_rnn_cell(num_units, dropout, cell_type)
+  else:
+    return _create_unidirectional_rnn_cell(num_units, dropout, cell_type)
+
+
+class StackedRNN(twml.layers.Layer):
+  """
+  Layer for stacking RNN modules.
+  This layer provides a unified interface for RNN modules that perform well on CPUs and GPUs.
+
+  Arguments:
+    num_units:
+      A list specifying the number of units per layer.
+    dropout:
+      Dropout applied to the input of each cell.
+      If a list, specifies the dropout used for each layer.
+      If a number, the same amount of dropout is used everywhere.
+      Defaults to 0.
+    is_training:
+      Flag to specify if the layer is used in training mode or not.
+    cell_type:
+      Specifies the type of RNN. Can be "LSTM". "GRU" is not yet implemented.
+    is_bidirectional:
+      Specifies if the stacked RNN layer is bidirectional.
+      This is for forward compatibility and is not yet implemented.
+      Defaults to False.
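+
+  Example (a minimal sketch; the names and sizes below are illustrative
+  assumptions):
+
+  .. code-block:: python
+
+    rnn = StackedRNN(num_units=[64, 32], dropout=0.2)
+    # inputs: [batch_size, max_len, emb_size], lengths: [batch_size]
+    final_state = rnn(inputs, sequence_lengths=lengths)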
+ """ + + def __init__(self, + num_units, + dropout=0, + is_training=True, + cell_type="LSTM", + is_bidirectional=False, + name="stacked_rnn"): + + super(StackedRNN, self).__init__(name=name) + + if (is_bidirectional): + raise NotImplementedError("Bidirectional RNN is not yet implemented") + + if (cell_type != "LSTM"): + raise NotImplementedError("Only LSTMs are supported") + + if not isinstance(num_units, (list, tuple)): + num_units = [num_units] + else: + num_units = num_units + + self.num_layers = len(num_units) + if not isinstance(dropout, (tuple, list)): + dropout = [dropout] * self.num_layers + else: + dropout = dropout + + self.is_training = is_training + + is_gpu_available = twml.contrib.utils.is_gpu_available() + same_unit_size = all(size == num_units[0] for size in num_units) + same_dropout_rate = any(val == dropout[0] for val in dropout) + + self.stacked_rnn_cell = None + self.num_units = num_units + self.dropout = dropout + self.cell_type = cell_type + self.is_bidirectional = is_bidirectional + + def build(self, input_shape): + self.stacked_rnn_cell = _create_regular_rnn_cell(self.num_units, + self.dropout, + self.cell_type, + self.is_bidirectional) + + def call(self, inputs, sequence_lengths): + """ + Arguments: + inputs: + A tensor of size [batch_size, max_sequence_length, embedding_size]. + sequence_lengths: + The length of each input sequence in the batch. Should be of size [batch_size]. + Returns: + final_output + The output of at the end of sequence_length. + """ + return self.stacked_rnn_cell(inputs, sequence_lengths) + + +def stacked_rnn(inputs, sequence_lengths, num_units, + dropout=0, is_training=True, + cell_type="LSTM", is_bidirectional=False, name="stacked_rnn"): + """Functional interface for StackedRNN + Arguments: + inputs: + A tensor of size [batch_size, max_sequence_length, embedding_size]. + sequence_lengths: + The length of each input sequence in the batch. Should be of size [batch_size]. + num_units: + A list specifying the number of units per layer. + dropout: + Dropout applied to the input of each cell. + If list, has to dropout used for each layer. + If number, the same amount of dropout is used everywhere. + Defaults to 0. + is_training: + Flag to specify if the layer is used in training mode or not. + cell_type: + Sepcifies the type of RNN. Can be "LSTM" or "GRU". + is_bidirectional: + Specifies if the stacked RNN layer is bidirectional. + Defaults to False. + Returns + outputs, state. + """ + rnn = StackedRNN(num_units, dropout, is_training, cell_type, is_bidirectional, name) + return rnn(inputs, sequence_lengths) diff --git a/twml/twml/contrib/layers/zscore_normalization.py b/twml/twml/contrib/layers/zscore_normalization.py new file mode 100644 index 000000000..8a1064965 --- /dev/null +++ b/twml/twml/contrib/layers/zscore_normalization.py @@ -0,0 +1,247 @@ +""" +Contains the twml.layers.ZscoreNormalization layer. +""" +from twml.layers.layer import Layer +import tensorflow.compat.v1 as tf + +from tensorflow.python.training import moving_averages + + +# This is copied from tensorflow.contrib.framework.python.ops.add_model_variable in 1.15 +# Not available in 2.x +# TODO: Figure out if this is really necessary. +def _add_model_variable(var): + """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection. + Args: + var: a variable. 
+ """ + if var not in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES): + tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var) + + +def update_moving_variable(batch_var, moving_var, decay, zero_debias=True, name=None): + update_op = moving_averages.assign_moving_average( + moving_var, batch_var, decay, zero_debias=zero_debias, name=None) + _add_model_variable(moving_var) + with tf.control_dependencies([update_op]): + return tf.identity(moving_var) + + +class ZscoreNormalization(Layer): + """ + Perform z-score normalization using moving mean and std. + Missing values are not included during mean/std calculation + This layer should only be used right after input layer. + + Args: + decay: + using large decay to include longer moving means. + data_type: + use float64 to prevent overflow during variance calculation. + name: + Layer name + Returns: + A layer representing the output of the ZscoreNormalization transformation. + """ + + def __init__( + self, + decay=0.9999, + data_type=tf.float64, + name=None, + **kwargs): + super(ZscoreNormalization, self).__init__(name=name, **kwargs) + self.epsilon = tf.constant(1., data_type) + self.decay = decay + self.data_type = data_type + + def build(self, input_shape): # pylint: disable=unused-argument + """Creates the moving_mean and moving_var tf.Variables of the layer.""" + input_dim = input_shape[1] + self.moving_mean = self.add_variable( + '{}_mean/EMA'.format(self.name), + initializer=tf.constant_initializer(), + shape=[input_dim], + dtype=self.data_type, + trainable=False + ) + self.moving_var = self.add_variable( + '{}_variance/EMA'.format(self.name), + initializer=tf.constant_initializer(), + shape=[input_dim], + dtype=self.data_type, + trainable=False + ) + self.built = True + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. the batch size may be unknown). + + """ + + return input_shape + + def _training_pass(self, input, dense_mask, input_dtype, handle_single, zero_debias): + epsilon = self.epsilon + moving_mean, moving_var = self.moving_mean, self.moving_var + # calculate the number of exisiting value for each feature + tensor_batch_num = tf.reduce_sum(tf.cast(dense_mask, self.data_type), axis=0) + mask_ones = tf.cast(tensor_batch_num, tf.bool) + eps_vector = tf.fill(tf.shape(tensor_batch_num), epsilon) + # the following filled 0 with epision + tensor_batch_num_eps = tf.where(mask_ones, + tensor_batch_num, + eps_vector + ) + tensor_batch_num_eps_broacast = tf.expand_dims(tensor_batch_num_eps, 0) + tensor_batch_divided = input / tensor_batch_num_eps_broacast + tensor_batch_mean = tf.reduce_sum(tensor_batch_divided, axis=0) + + # update moving mean here, and use it to calculate the std. + tensor_moving_mean = update_moving_variable(tensor_batch_mean, moving_mean, self.decay, + zero_debias, name="mean_ema_op") + + tensor_batch_sub_mean = input - tf.expand_dims(tensor_moving_mean, 0) + tensor_batch_sub_mean = tf.where(dense_mask, + tensor_batch_sub_mean, + tf.zeros_like(tensor_batch_sub_mean)) + # divided by sqrt(n) before square, and then do summation for numeric stability. 
+    broad_sqrt_num_eps = tf.expand_dims(tf.sqrt(tensor_batch_num_eps), 0)
+    tensor_batch_sub_mean_div = tensor_batch_sub_mean / broad_sqrt_num_eps
+    tensor_batch_sub_mean_div_square = tf.square(tensor_batch_sub_mean_div)
+    tensor_batch_var = tf.reduce_sum(tensor_batch_sub_mean_div_square, axis=0)
+
+    # update the moving var here; don't replace 0 with eps before updating.
+    tensor_moving_var = update_moving_variable(tensor_batch_var, moving_var, self.decay,
+                                               zero_debias, name="var_ema_op")
+
+    # if std is 0, replace it with epsilon
+    tensor_moving_std = tf.sqrt(tensor_moving_var)
+    tensor_moving_std_eps = tf.where(tf.equal(tensor_moving_std, 0),
+                                     eps_vector,
+                                     tensor_moving_std)
+
+    missing_input_norm = tensor_batch_sub_mean / tf.expand_dims(tensor_moving_std_eps, 0)
+
+    if handle_single:
+      # if std == 0 and the value is not missing, reset it to 1.
+      moving_var_mask_zero = tf.math.equal(tensor_moving_var, 0)
+      moving_var_mask_zero = tf.expand_dims(moving_var_mask_zero, 0)
+      missing_input_norm = tf.where(
+        tf.math.logical_and(dense_mask, moving_var_mask_zero),
+        tf.ones_like(missing_input_norm),
+        missing_input_norm
+      )
+    if input_dtype != self.data_type:
+      missing_input_norm = tf.cast(missing_input_norm, input_dtype)
+    return missing_input_norm
+
+  def _infer_pass(self, input, dense_mask, input_dtype, handle_single):
+    epsilon = tf.cast(self.epsilon, input_dtype)
+    testing_moving_mean = tf.cast(self.moving_mean, input_dtype)
+    tensor_moving_std = tf.cast(tf.sqrt(self.moving_var), input_dtype)
+
+    broad_mean = tf.expand_dims(testing_moving_mean, 0)
+    tensor_batch_sub_mean = input - broad_mean
+
+    tensor_batch_sub_mean = tf.where(dense_mask,
+                                     tensor_batch_sub_mean,
+                                     tf.zeros_like(tensor_batch_sub_mean)
+                                     )
+    tensor_moving_std_eps = tf.where(tf.equal(tensor_moving_std, 0),
+                                     tf.fill(tf.shape(tensor_moving_std), epsilon),
+                                     tensor_moving_std)
+    missing_input_norm = tensor_batch_sub_mean / tf.expand_dims(tensor_moving_std_eps, 0)
+    if handle_single:
+      # if std == 0 and the value is not missing, reset it to 1.
+      moving_var_broad = tf.expand_dims(tensor_moving_std, 0)
+      moving_var_mask_zero = tf.math.logical_not(tf.cast(moving_var_broad, tf.bool))
+
+      missing_input_norm = tf.where(tf.math.logical_and(dense_mask, moving_var_mask_zero),
                                    tf.ones_like(missing_input_norm),
                                    missing_input_norm
                                    )
+    return missing_input_norm
+
+  def call(
+      self,
+      input,
+      is_training,
+      dense_mask=None,
+      zero_debias=True,
+      handle_single=False):
+    """
+    Args:
+    -----------
+    input: B x D : float32/float64
+      Missing values must be set to 0.
+    is_training: bool
+      Training phase or testing phase.
+    dense_mask: B x D : bool
+      Missing values should be marked as 0, non-missing as 1. Same shape as input.
+    zero_debias: bool
+      Bias correction of the moving average (biased towards 0 in the beginning;
+      see the Adam paper, https://arxiv.org/abs/1412.6980).
+    handle_single: bool
+      If std == 0 and the feature is not a missing value, set the value to 1 instead of 0.
+      This is very rare if the input only consists of continuous features.
+      But if one-hot features are included, they will all have the same value 1;
+      in that case, make sure to set handle_single to True.
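+
+    Example:
+      A minimal sketch (assumes `features` is a B x D float tensor in which
+      missing values are already set to 0; the names are illustrative):
+
+      .. code-block:: python
+
+        norm = ZscoreNormalization(decay=0.999, name="zscore_norm")
+        normalized = norm(features, is_training=True)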
+ """ + + if dense_mask is None: + dense_mask = tf.math.logical_not(tf.equal(input, 0)) + input_dtype = input.dtype + + if is_training: + if input_dtype != self.data_type: + input = tf.cast(input, self.data_type) + return self._training_pass(input, dense_mask, input_dtype, handle_single, zero_debias) + else: + return self._infer_pass(input, dense_mask, input_dtype, handle_single) + + +def zscore_normalization( + input, + is_training, + decay=0.9999, + data_type=tf.float64, + name=None, + dense_mask=None, + zero_debias=True, + handle_single=False, **kwargs): + """ + Args: + ------------ + input: B x D : float32/float64 + missing value must be set to 0. + is_training: bool + training phase or testing phase + decay: + using large decay to include longer moving means. + data_type: + use float64 to zprevent overflow during variance calculation. + name: + Layer name + dense_mask: B x D : bool + missing value should be marked as 0, non-missing as 1. same shape as input + zero_debias: bool + bias correction of the moving average. (biased towards 0 in the beginning. + see adam paper. https://arxiv.org/abs/1412.6980) + handle_single: bool + if std==0, and feature is not missing value, set the value to 1, instead of 0. + This is super rare if input only consists of continous feature. + But if one-hot feature is included, + they will all have same values 1, in that case, make sure to set handle_single to true. + """ + + norm_layer = ZscoreNormalization(decay=decay, data_type=data_type, name=name, **kwargs) + return norm_layer(input, + is_training, + dense_mask=dense_mask, + zero_debias=zero_debias, + handle_single=handle_single) diff --git a/twml/twml/contrib/metrics/__init__.py b/twml/twml/contrib/metrics/__init__.py new file mode 100644 index 000000000..37e6563c9 --- /dev/null +++ b/twml/twml/contrib/metrics/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=wildcard-import +"""This module contains experimental metric(s) for search and ranking""" + +from .search_metrics import get_search_metric_fn, ndcg # noqa: F401 +from .metrics import * # noqa: F401 diff --git a/twml/twml/contrib/metrics/metrics.py b/twml/twml/contrib/metrics/metrics.py new file mode 100644 index 000000000..dea1a5273 --- /dev/null +++ b/twml/twml/contrib/metrics/metrics.py @@ -0,0 +1,209 @@ +""" +Module containing extra tensorflow metrics used at Twitter. +This module conforms to conventions used by tf.metrics.*. +In particular, each metric constructs two subgraphs: value_op and update_op: + - The value op is used to fetch the current metric value. + - The update_op is used to accumulate into the metric. + +Note: similar to tf.metrics.*, metrics in here do not support multi-label learning. +We will have to write wrapper classes to create one metric per label. + +Note: similar to tf.metrics.*, batches added into a metric via its update_op are cumulative! + +""" + +from collections import OrderedDict + +import tensorflow.compat.v1 as tf +from twml.metrics import get_multi_binary_class_metric_fn + + + +# checkstyle: noqa +def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, predcols=None): + + def get_eval_metric_ops(graph_output, labels, weights): + if predcols is None: + preds = graph_output['output'] + else: + if isinstance(predcols, int): + predcol_list=[predcols] + else: + predcol_list=list(predcols) + for col in predcol_list: + assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !' 
+ preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col] + labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col] + + predInfo = {'output': preds} + if 'threshold' in graph_output: + predInfo['threshold'] = graph_output['threshold'] + if 'hard_output' in graph_output: + predInfo['hard_output'] = graph_output['hard_output'] + + metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim) + metrics_op_res = metrics_op(predInfo, labels, weights) + return metrics_op_res + + return get_eval_metric_ops + + + +# Numeric Prediction Performance among TopK Predictions +def mean_numeric_label_topK(labels, predictions, weights, name, topK_id): + top_k_labels = tf.gather(params=labels, indices=topK_id, axis=0) # [topK, 1] + return tf.metrics.mean(values=top_k_labels, name=name) + +def mean_gated_numeric_label_topK(labels, predictions, weights, name, topK_id, bar=2.0): + assert isinstance(bar, int) or isinstance(bar, float), "bar must be int or float" + top_k_labels = tf.gather(params=labels, indices=topK_id, axis=0) # [topK, 1] + gated_top_k_labels = tf.cast(top_k_labels > bar*1.0, tf.int32) + return tf.metrics.mean(values=gated_top_k_labels, name=name) + +SUPPORTED_NUMERIC_METRICS = { + 'mean_numeric_label_topk': mean_numeric_label_topK, + 'mean_gated_numeric_label_topk': mean_gated_numeric_label_topK +} +DEFAULT_NUMERIC_METRICS = ['mean_numeric_label_topk', 'mean_gated_numeric_label_topk'] + + + +def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, topK=(5,5,5), predcol=None, labelcol=None): + """ + :param targetMetrics: Target Metric List + :param supportedMetrics_op: Supported Metric Operators Dict + :param metrics: Metric Set to evaluate + :param topK: (topK_min, topK_max, topK_delta) Tuple + :param predcol: Prediction Column Index + :param labelcol: Label Column Index + :return: + """ + # pylint: disable=dict-keys-not-iterating + if targetMetrics is None or supportedMetrics_op is None: + raise ValueError("Invalid Target Metric List/op !") + + targetMetrics = set([m.lower() for m in targetMetrics]) + if metrics is None: + metrics = list(targetMetrics) + else: + metrics = [m.lower() for m in metrics if m.lower() in targetMetrics] + + num_k = int((topK[1]-topK[0])/topK[2]+1) + topK_list = [topK[0]+d*topK[2] for d in range(num_k)] + if 1 not in topK_list: + topK_list = [1] + topK_list + + + def get_eval_metric_ops(graph_output, labels, weights): + """ + graph_output: + dict that is returned by build_graph given input features. + labels: + target labels associated to batch. + weights: + weights of the samples.. + """ + eval_metric_ops = OrderedDict() + + if predcol is None: + pred = graph_output['output'] + else: + assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !' + assert labelcol is not None + pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1]) + labels = tf.reshape(labels[:, labelcol], shape=[-1, 1]) + numOut = graph_output['output'].shape[1] + pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1]) + + # add metrics to eval_metric_ops dict + for metric_name in metrics: + metric_name = metric_name.lower() # metric name are case insensitive. 
+ + if metric_name in supportedMetrics_op: + metric_factory = supportedMetrics_op.get(metric_name) + + if 'topk' not in metric_name: + value_op, update_op = metric_factory( + labels=labels, + predictions=pred, + weights=weights, + name=metric_name) + eval_metric_ops[metric_name] = (value_op, update_op) + else: + for K in topK_list: + K_min = tf.minimum(K, tf.shape(pred_score)[0]) + topK_id = tf.nn.top_k(tf.reshape(pred_score, shape=[-1]), k=K_min)[1] # [topK] + value_op, update_op = metric_factory( + labels=labels, + predictions=pred, + weights=weights, + name=metric_name+'__k_'+str(K), + topK_id=topK_id) + eval_metric_ops[metric_name+'__k_'+str(K)] = (value_op, update_op) + + else: + raise ValueError('Cannot find the metric named ' + metric_name) + + return eval_metric_ops + + return get_eval_metric_ops + + + +def get_numeric_metric_fn(metrics=None, topK=(5,5,5), predcol=None, labelcol=None): + if metrics is None: + metrics = list(DEFAULT_NUMERIC_METRICS) + metrics = list(set(metrics)) + + metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS), + supportedMetrics_op=SUPPORTED_NUMERIC_METRICS, + metrics=metrics, topK=topK, predcol=predcol, labelcol=labelcol) + return metric_op + + + +def get_single_binary_task_metric_fn(metrics, classnames, topK=(5,5,5), use_topK=False): + """ + graph_output['output']: [BatchSz, 1] [pred_Task1] + labels: [BatchSz, 2] [Task1, NumericLabel] + """ + def get_eval_metric_ops(graph_output, labels, weights): + metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames) + classnames_unw = ['unweighted_'+cs for cs in classnames] + metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames_unw) + + metrics_base_res = metric_op_base(graph_output, labels, weights) + metrics_unw_res = metric_op_unw(graph_output, labels, None) + metrics_base_res.update(metrics_unw_res) + + if use_topK: + metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=0, labelcol=1) + metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) + metrics_base_res.update(metrics_numeric_res) + return metrics_base_res + + return get_eval_metric_ops + + +def get_dual_binary_tasks_metric_fn(metrics, classnames, topK=(5,5,5), use_topK=False): + """ + graph_output['output']: [BatchSz, 3] [pred_Task1, pred_Task2, Score] + labels: [BatchSz, 3] [Task1, Task2, NumericLabel] + """ + def get_eval_metric_ops(graph_output, labels, weights): + + metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames) + classnames_unw = ['unweighted_'+cs for cs in classnames] + metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames_unw) + + metrics_base_res = metric_op_base(graph_output, labels, weights) + metrics_unw_res = metric_op_unw(graph_output, labels, None) + metrics_base_res.update(metrics_unw_res) + + if use_topK: + metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=2, labelcol=2) + metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) + metrics_base_res.update(metrics_numeric_res) + return metrics_base_res + + return get_eval_metric_ops diff --git a/twml/twml/contrib/metrics/search_metrics.py b/twml/twml/contrib/metrics/search_metrics.py new file mode 100644 index 000000000..7d7a502f1 --- /dev/null +++ b/twml/twml/contrib/metrics/search_metrics.py @@ -0,0 +1,292 @@ +""" +Module containing extra tensorflow metrics used at Twitter. 
+This module conforms to conventions used by tf.metrics.*.
+In particular, each metric constructs two subgraphs: value_op and update_op:
+  - The value op is used to fetch the current metric value.
+  - The update_op is used to accumulate into the metric.
+
+Note: similar to tf.metrics.*, metrics in here do not support multi-label learning.
+We will have to write wrapper classes to create one metric per label.
+
+Note: similar to tf.metrics.*, batches added into a metric via its update_op are cumulative!
+
+"""
+
+from collections import OrderedDict
+from functools import partial
+
+import tensorflow.compat.v1 as tf
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes, ops
+from tensorflow.python.ops import array_ops, state_ops
+import twml
+from twml.contrib.utils import math_fns
+
+
+def ndcg(labels, predictions,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None,
+         top_k_int=1):
+  # pylint: disable=unused-argument
+  """
+  Compute the full normalized discounted cumulative gain (ndcg) based on predictions.
+  ndcg = dcg_k / idcg_k, where k is a cut-off ranking position.
+  There are a few variants of ndcg.
+  The dcg (discounted cumulative gain) formula used in
+  twml.contrib.metrics.ndcg is::
+
+    \\sum_{i=1}^{k} \\frac{2^{relevance\\_score_i} - 1}{\\log_{2}(i + 1)}
+
+  Here k is the number of items to be ranked in a batch/query.
+  Whether k should instead be replaced with a fixed value is still under discussion.
+  The scores in predictions are transformed to order and relevance scores to calculate ndcg.
+  A relevance score measures how relevant a DataRecord is to a particular query.
+
+  Arguments:
+    labels: the ground truth value.
+    predictions: the predicted values, whose shape must match labels. Ignored for CTR computation.
+    metrics_collections: optional list of collections to add this metric into.
+    updates_collections: optional list of collections to add the associated update_op into.
+    name: an optional variable_scope name.
+
+  Returns:
+    ndcg: A `Tensor` representing the ndcg score.
+    update_op: An update operation used to accumulate data into this metric.
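+
+  Example:
+    With relevance labels [3, 2, 0], ranked in that order by the predictions,
+    dcg = (2^3 - 1)/log2(2) + (2^2 - 1)/log2(3) + (2^0 - 1)/log2(4) ~= 8.89;
+    this ordering is already ideal, so idcg is the same and ndcg = 1.0.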
+ """ + with tf.variable_scope(name, 'ndcg', (labels, predictions)): + label_scores = tf.to_float(labels, name='label_to_float') + predicted_scores = tf.to_float(predictions, name='predictions_to_float') + + if context.executing_eagerly(): + raise RuntimeError('ndcg is not supported when eager execution ' + 'is enabled.') + + total_ndcg = _metric_variable([], dtypes.float32, name='total_ndcg') + count_query = _metric_variable([], dtypes.float32, name='query_count') + + # actual ndcg cutoff position top_k_int + max_prediction_size = array_ops.size(predicted_scores) + top_k_int = tf.minimum(max_prediction_size, top_k_int) + # the ndcg score of the batch + ndcg = math_fns.cal_ndcg(label_scores, + predicted_scores, top_k_int=top_k_int) + # add ndcg of the current batch to total_ndcg + update_total_op = state_ops.assign_add(total_ndcg, ndcg) + with ops.control_dependencies([ndcg]): + # count_query stores the number of queries + # count_query increases by 1 for each batch/query + update_count_op = state_ops.assign_add(count_query, 1) + + mean_ndcg = math_fns.safe_div(total_ndcg, count_query, 'mean_ndcg') + update_op = math_fns.safe_div(update_total_op, update_count_op, 'update_mean_ndcg_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_ndcg) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return mean_ndcg, update_op + + +# Copied from metrics_impl.py with minor modifications. +# https://github.com/tensorflow/tensorflow/blob/v1.5.0/tensorflow/python/ops/metrics_impl.py#L39 +def _metric_variable(shape, dtype, validate_shape=True, name=None): + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" + + return tf.Variable( + lambda: tf.zeros(shape, dtype), + trainable=False, + collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES], + validate_shape=validate_shape, + name=name) + + +# binary metric_name: (metric, requires thresholded output) +SUPPORTED_BINARY_CLASS_METRICS = { + # TWML binary metrics + 'rce': (twml.metrics.rce, False), + 'nrce': (partial(twml.metrics.rce, normalize=True), False), + # CTR measures positive sample ratio. This terminology is inherited from Ads. + 'ctr': (twml.metrics.ctr, False), + # predicted CTR measures predicted positive ratio. + 'predicted_ctr': (twml.metrics.predicted_ctr, False), + # thresholded metrics + 'accuracy': (tf.metrics.accuracy, True), + 'precision': (tf.metrics.precision, True), + 'recall': (tf.metrics.recall, True), + # tensorflow metrics + 'roc_auc': (partial(tf.metrics.auc, curve='ROC'), False), + 'pr_auc': (partial(tf.metrics.auc, curve='PR'), False), +} + +# search metric_name: metric +SUPPORTED_SEARCH_METRICS = { + # TWML search metrics + # ndcg needs the raw prediction scores to sort + 'ndcg': ndcg, +} + + +def get_search_metric_fn(binary_metrics=None, search_metrics=None, + ndcg_top_ks=[1, 3, 5, 10], use_binary_metrics=False): + """ + Returns a function having signature: + + .. code-block:: python + + def get_eval_metric_ops(graph_output, labels, weights): + ... + return eval_metric_ops + + where the returned eval_metric_ops is a dict of common evaluation metric + Ops for ranking. See `tf.estimator.EstimatorSpec + `_ + for a description of eval_metric_ops. The graph_output is a the result + dict returned by build_graph. Labels and weights are tf.Tensors. + + The following graph_output keys are recognized: + output: + the raw predictions. Required. 
+    threshold:
+      Only used in SUPPORTED_BINARY_CLASS_METRICS,
+      if the labels are 0s and 1s.
+      A value between 0 and 1 used to threshold the output into a hard_output.
+      Defaults to 0.5 when threshold and hard_output are missing.
+      Either threshold or hard_output can be provided, but not both.
+    hard_output:
+      Only used in SUPPORTED_BINARY_CLASS_METRICS.
+      A thresholded output. Either threshold or hard_output can be provided, but not both.
+
+  Arguments:
+    (the binary metrics below are only used in pointwise learning-to-rank)
+
+    binary_metrics (list of String):
+      a list of metrics of interest. E.g. ['ctr', 'accuracy', 'rce']
+      These metrics are evaluated and reported to tensorboard *during the eval phases only*.
+      Supported metrics:
+        - ctr (same as positive sample ratio.)
+        - rce (cross entropy loss compared to the baseline model of always predicting ctr)
+        - nrce (normalized rce, do not use this one if you do not understand what it is)
+        - pr_auc
+        - roc_auc
+        - accuracy (percentage of predictions that are correct)
+        - precision (true positives) / (true positives + false positives)
+        - recall (true positives) / (true positives + false negatives)
+
+      NOTE: accuracy / precision / recall apply to binary classification problems only.
+      I.e. a prediction is only considered correct if it matches the label. E.g. if the label
+      is 1.0, and the prediction is 0.99, it does not get credit. If you want to use
+      precision / recall / accuracy metrics with soft predictions, you'll need to threshold
+      your predictions into hard 0/1 labels.
+
+      When binary_metrics is None (the default), it defaults to all supported metrics.
+
+    search_metrics (list of String):
+      a list of metrics of interest. E.g. ['ndcg']
+      These metrics are evaluated and reported to tensorboard *during the eval phases only*.
+      Supported metrics:
+        - ndcg
+
+      NOTE: ndcg works for ranking-related problems.
+      A batch contains all DataRecords that belong to the same query.
+      If pair_in_batch_mode is used in scalding -- a batch contains a pair of DataRecords
+      that belong to the same query and have different labels -- ndcg does not apply there.
+
+      When search_metrics is None (the default), it defaults to all supported search metrics,
+      currently only 'ndcg'.
+
+    ndcg_top_ks (list of integers):
+      The cut-off ranking positions for a query.
+      When ndcg_top_ks is None or empty (the default), it defaults to [1, 3, 5, 10].
+
+    use_binary_metrics:
+      False (default).
+      Only set it to True in pointwise learning-to-rank.
+  """
+  # pylint: disable=dict-keys-not-iterating
+
+  if ndcg_top_ks is None or not ndcg_top_ks:
+    ndcg_top_ks = [1, 3, 5, 10]
+
+  if search_metrics is None:
+    search_metrics = list(SUPPORTED_SEARCH_METRICS.keys())
+
+  if binary_metrics is None and use_binary_metrics:
+    # SUPPORTED_BINARY_CLASS_METRICS were added in twml.metrics as well;
+    # they are only used in pointwise learning-to-rank
+    binary_metrics = list(SUPPORTED_BINARY_CLASS_METRICS.keys())
+
+  def get_eval_metric_ops(graph_output, labels, weights):
+    """
+    graph_output:
+      dict that is returned by build_graph given input features.
+    labels:
+      target labels associated with the batch.
+    weights:
+      weights of the samples.
+ """ + + eval_metric_ops = OrderedDict() + + preds = graph_output['output'] + + threshold = graph_output['threshold'] if 'threshold' in graph_output else 0.5 + + hard_preds = graph_output.get('hard_output') + # hard_preds is a tensor + # check hard_preds is None and then check if it is empty + if hard_preds is None or tf.equal(tf.size(hard_preds), 0): + hard_preds = tf.greater_equal(preds, threshold) + + # add search metrics to eval_metric_ops dict + for metric_name in search_metrics: + metric_name = metric_name.lower() # metric name are case insensitive. + + if metric_name in eval_metric_ops: + # avoid adding duplicate metrics. + continue + + search_metric_factory = SUPPORTED_SEARCH_METRICS.get(metric_name) + if search_metric_factory: + if metric_name == 'ndcg': + for top_k in ndcg_top_ks: + # metric name will show as ndcg_1, ndcg_10, ... + metric_name_ndcg_top_k = metric_name + '_' + str(top_k) + top_k_int = tf.constant(top_k, dtype=tf.int32) + # Note: having weights in ndcg does not make much sense + # Because ndcg already has position weights/discounts + # Thus weights are not applied in ndcg metric + value_op, update_op = search_metric_factory( + labels=labels, + predictions=preds, + name=metric_name_ndcg_top_k, + top_k_int=top_k_int) + eval_metric_ops[metric_name_ndcg_top_k] = (value_op, update_op) + else: + raise ValueError('Cannot find the search metric named ' + metric_name) + + if use_binary_metrics: + # add binary metrics to eval_metric_ops dict + for metric_name in binary_metrics: + + if metric_name in eval_metric_ops: + # avoid adding duplicate metrics. + continue + + metric_name = metric_name.lower() # metric name are case insensitive. + binary_metric_factory, requires_threshold = SUPPORTED_BINARY_CLASS_METRICS.get(metric_name) + if binary_metric_factory: + value_op, update_op = binary_metric_factory( + labels=labels, + predictions=(hard_preds if requires_threshold else preds), + weights=weights, + name=metric_name) + eval_metric_ops[metric_name] = (value_op, update_op) + else: + raise ValueError('Cannot find the binary metric named ' + metric_name) + + return eval_metric_ops + + return get_eval_metric_ops diff --git a/twml/twml/contrib/optimizers/__init__.py b/twml/twml/contrib/optimizers/__init__.py new file mode 100644 index 000000000..112b2b410 --- /dev/null +++ b/twml/twml/contrib/optimizers/__init__.py @@ -0,0 +1,4 @@ +# pylint: disable=wildcard-import +"""This module contains experimental optimizer classes""" +from .deep_gradient_compression_optimizer import DeepGradientCompressionOptimizer # noqa: F401 +from .pruning_optimizer import PruningOptimizer # noqa: F401 diff --git a/twml/twml/contrib/optimizers/deep_gradient_compression_optimizer.py b/twml/twml/contrib/optimizers/deep_gradient_compression_optimizer.py new file mode 100644 index 000000000..2c71ed13f --- /dev/null +++ b/twml/twml/contrib/optimizers/deep_gradient_compression_optimizer.py @@ -0,0 +1,180 @@ +""" +A custom optimizer to implement Deep Gradient Compression. The general idea of +gradient compression is to compress the gradients exchanged across machines, +in order to reduce the communication overhead of distributing computing efforts. +More details in https://arxiv.org/abs/1712.01887 +""" + +# TODO: Test how much communication overhead this DeepGradientCompressionOptimizer can reduce under +# multi-GPU and distributed setting. 
+ +import tensorflow.compat.v1 as tf + + +def compute_threshold(grad, density): + """ + A utility function to compute the threshold for gradient sparsification, given the gradient + tensor and the density. + Args: + grad(tf.Tensor): + Gradient tensor for some variable. + density(float): + Density degree when sparsifying gradients. + Returns(float): + Threshold for gradient sparsification. + """ + flat_grad = tf.reshape(grad, [-1]) + abs_flat_grad = tf.abs(flat_grad) + size = tf.shape(abs_flat_grad)[0] + k = tf.maximum(tf.constant(1), + tf.cast(tf.scalar_mul(density, tf.cast(size, tf.float32)), tf.int32)) + topk, _ = tf.nn.top_k(abs_flat_grad, k, False) + return topk[-1] + + +def get_top_row_indices(values, density): + """ + A utility function to get indices of most significant rows, given the density degree. + Args: + values(tf.Tensor): + Gradient or locally accumulated gradient for some variable. + density(float): + Density degree when filtering out rows. + Returns(list(int)): + Indices of most significant rows. + """ + abs_values = tf.abs(values) + + try: + row_num = tf.shape(abs_values)[0] + k = tf.maximum(tf.constant(1), + tf.cast(tf.scalar_mul(density, tf.cast(row_num, tf.float32)), tf.int32)) + row_sums = tf.squeeze(tf.reduce_sum(values, axis=1, keepdims=True)) + _, top_row_indices = tf.nn.top_k(row_sums, k=k, sorted=False) + # print "abs_values", abs_values, "row_sums", row_sums + return top_row_indices + # return tf.range(row_num) + + except ValueError: # if the tensor is 0-D or 1-D + return None + + +class DeepGradientCompressionOptimizer(tf.train.GradientDescentOptimizer): + """ + A custom optimizer to implement Deep Gradient Compression (https://arxiv.org/abs/1712.01887). + """ + + def __init__(self, learning_rate, use_locking=False, name="Sparse", + density=1.0, + density_decay=False, + density_decay_steps=10000, + density_decay_rate=0.5, + min_density=0.1, + accumulation=False): + super(DeepGradientCompressionOptimizer, self).__init__(learning_rate, use_locking, name) + self._initial_density_t = tf.convert_to_tensor(density) + self._density_decay = density_decay + dtype = self._initial_density_t.dtype + self._density_decay_steps_t = tf.convert_to_tensor(density_decay_steps, dtype) + self._density_decay_rate_t = tf.convert_to_tensor(density_decay_rate, dtype) + self._min_density_t = tf.convert_to_tensor(min_density, dtype) + self._accumulation = accumulation + + def _prepare(self): + super(DeepGradientCompressionOptimizer, self)._prepare() + if not self._density_decay: + self._density_t = self._initial_density_t + else: + dtype = self._initial_density_t.dtype + global_step = tf.cast(tf.train.get_global_step(), dtype) + p = tf.floor(tf.divide(global_step, self._density_decay_steps_t)) + decayed_density = tf.multiply(self._initial_density_t, + tf.pow(self._density_decay_rate_t, p)) + self._density_t = tf.maximum(self._min_density_t, decayed_density) + + def _create_slots(self, var_list): + """ + Create a slot variable to accumulate gradients locally for each variable in `var_list`. + Args: + var_list(list(tf.Variable)): + List of variables to accumulate gradients locally for. 
+ """ + for var in var_list: + self._zeros_slot(var, "g_buffer", self._name) + + def _apply_dense(self, grad, var): + if not self._accumulation: + top_row_indices = get_top_row_indices(grad, self._density_t) + + if top_row_indices is None: + return super(DeepGradientCompressionOptimizer, self)._apply_dense(grad, var) + + sparsified_values = tf.gather(grad, top_row_indices) + sparsified_indices = top_row_indices + + sparsified_grad = tf.IndexedSlices(sparsified_values, sparsified_indices) + + return super(DeepGradientCompressionOptimizer, self)._apply_sparse_duplicate_indices( + sparsified_grad, var) + + else: + g_buffer = self.get_slot(var, "g_buffer") + + g_buffer = tf.assign_add(g_buffer, grad) + + top_row_indices = get_top_row_indices(g_buffer, self._density_t) + + if top_row_indices is None: + return super(DeepGradientCompressionOptimizer, self)._apply_dense(grad, var) + + sparsified_values = tf.gather(g_buffer, top_row_indices) + sparsified_indices = top_row_indices + + sparsified_grad = tf.IndexedSlices(sparsified_values, sparsified_indices) + + update_var = super(DeepGradientCompressionOptimizer, self)._apply_sparse_duplicate_indices( + sparsified_grad, var) + + update_g_buffer = tf.scatter_update(g_buffer, sparsified_indices, tf.zeros_like( + sparsified_values)) + + return tf.group(*[update_var, update_g_buffer]) + + def _apply_sparse_duplicate_indices(self, grad, var): + if not self._accumulation: + top_row_indices = get_top_row_indices(grad.values, self._density_t) + + if top_row_indices is None: + return super(DeepGradientCompressionOptimizer, self)._apply_sparse_duplicate_indices(grad, var) # noqa: E501 + + sparsified_values = tf.gather(grad.values, top_row_indices) + sparsified_indices = tf.gather(grad.indices, top_row_indices) + + sparsified_grad = tf.IndexedSlices(sparsified_values, sparsified_indices) + + return super(DeepGradientCompressionOptimizer, self)._apply_sparse_duplicate_indices( + sparsified_grad, var) + + else: + g_buffer = self.get_slot(var, "g_buffer") + + g_buffer = tf.scatter_update(g_buffer, grad.indices, grad.values) + + top_row_indices = get_top_row_indices(g_buffer, self._density_t) + + if top_row_indices is None: + return super(DeepGradientCompressionOptimizer, + self)._apply_sparse_duplicate_indices(grad, var) + + sparsified_values = tf.gather(g_buffer, top_row_indices) + sparsified_indices = top_row_indices + + sparsified_grad = tf.IndexedSlices(sparsified_values, sparsified_indices) + + update_var = super(DeepGradientCompressionOptimizer, self)._apply_sparse_duplicate_indices( + sparsified_grad, var) + + update_g_buffer = tf.scatter_update(g_buffer, sparsified_indices, tf.zeros_like( + sparsified_values)) + + return tf.group(*[update_var, update_g_buffer]) diff --git a/twml/twml/contrib/optimizers/pruning_optimizer.py b/twml/twml/contrib/optimizers/pruning_optimizer.py new file mode 100644 index 000000000..2bcd612ed --- /dev/null +++ b/twml/twml/contrib/optimizers/pruning_optimizer.py @@ -0,0 +1,164 @@ +""" +Provides a general optimizer for pruning features of a neural network. + +The optimizer estimates the computational cost of features, combines this information with pruning +signals indicating their usefulness, and disables features via binary masks at regular intervals. 
+ +To make a layer prunable, use `twml.contrib.pruning.apply_mask`: + + dense1 = tf.layers.dense(inputs=inputs, units=50, activation=tf.nn.relu) + dense1 = apply_mask(dense1) + +To prune the network, apply PruningOptimizer to any cross-entropy loss: + + loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + + optimizer = PruningOptimizer(learning_rate=0.001, momentum=0.5) + minimize = optimizer.minimize( + loss=loss, + prune_every=10, + burn_in=100, + global_step=tf.train.get_global_step()) +""" + +import tensorflow.compat.v1 as tf + +from twml.contrib.pruning import computational_cost, prune, update_pruning_signals +from twml.contrib.pruning import MASK_COLLECTION + + +class PruningOptimizer(tf.train.MomentumOptimizer): + """ + Updates parameters with SGD and pruning masks using Fisher pruning. + + Arguments: + learning_rate: float + Learning rate of SGD + + momentum: float + Momentum used by SGD + + use_locking: bool + If `True`, use locks for update operations + + name: str + Optional name prefix for the operations created when applying gradients + + use_nesterov: bool + If `True`, use Nesterov momentum + """ + + def __init__( + self, + learning_rate, + momentum=0.9, + use_locking=False, + name="PruningOptimizer", + use_nesterov=False): + super(PruningOptimizer, self).__init__( + learning_rate=learning_rate, + momentum=momentum, + use_locking=use_locking, + name=name, + use_nesterov=use_nesterov) + + def minimize( + self, + loss, + prune_every=100, + burn_in=0, + decay=.96, + flops_weight='AUTO', + flops_target=0, + update_params=None, + method='Fisher', + *args, + **kwargs): + """ + Create operations to minimize loss and to prune features. + + A pruning signal measures the importance of feature maps. This is weighed against the + computational cost of computing a feature map. Features are then iteratively pruned + based on a weighted average of feature importance S and computational cost C (in FLOPs): + + $$S + w * C$$ + + Setting `flops_weight` to 'AUTO' is the most convenient and recommended option, but not + necessarily optimal. 
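+
+        Concretely, with `flops_weight='AUTO'` the pruning signal for each feature
+        becomes roughly S / C (importance per FLOP), while a nonzero float weight w
+        enters as S - w * C; in both cases the feature with the smallest resulting
+        signal is pruned first.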
+ + Arguments: + loss: tf.Tensor + The value to minimize + + prune_every: int + One entry of a mask is set to zero only every few update steps + + burn_in: int + Pruning starts only after this many parameter updates + + decay: float + Controls exponential moving average of pruning signals + + flops_weight: float or str + Controls the targeted trade-off between computational complexity and performance + + flops_target: float + Stop pruning when computational complexity is less or this many floating point ops + + update_params: tf.Operation + Optional training operation used instead of MomentumOptimizer to update parameters + + method: str + Method used to compute pruning signal (currently only supports 'Fisher') + + Returns: + A `tf.Operation` updating parameters and pruning masks + + References: + * Theis et al., Faster gaze prediction with dense networks and Fisher pruning, 2018 + """ + + # gradient-based updates of parameters + if update_params is None: + update_params = super(PruningOptimizer, self).minimize(loss, *args, **kwargs) + + masks = tf.get_collection(MASK_COLLECTION) + + with tf.variable_scope('pruning_opt', reuse=True): + # estimate computational cost per data point + batch_size = tf.cast(tf.shape(masks[0].tensor), loss.dtype)[0] + cost = tf.divide(computational_cost(loss), batch_size, name='computational_cost') + + tf.summary.scalar('computational_cost', cost) + + if masks: + signals = update_pruning_signals(loss, masks=masks, decay=decay, method=method) + + # estimate computational cost per feature map + costs = tf.gradients(cost, masks) + + # trade off computational complexity and performance + if flops_weight.upper() == 'AUTO': + signals = [s / (c + 1e-6) for s, c in zip(signals, costs)] + elif not isinstance(flops_weight, float) or flops_weight != 0.: + signals = [s - flops_weight * c for s, c in zip(signals, costs)] + + counter = tf.Variable(0, name='pruning_counter') + counter = tf.assign_add(counter, 1, use_locking=True) + + # only prune every so often after a burn-in phase + pruning_cond = tf.logical_and(counter > burn_in, tf.equal(counter % prune_every, 0)) + + # stop pruning after reaching threshold + if flops_target > 0: + pruning_cond = tf.logical_and(pruning_cond, tf.greater(cost, flops_target)) + + update_masks = tf.cond( + pruning_cond, + lambda: prune(signals, masks=masks), + lambda: tf.group(masks)) + + return tf.group([update_params, update_masks]) + + # no masks found + return update_params diff --git a/twml/twml/contrib/parsers.py b/twml/twml/contrib/parsers.py new file mode 100644 index 000000000..a27f2acbd --- /dev/null +++ b/twml/twml/contrib/parsers.py @@ -0,0 +1,21 @@ +''' +Contains implementations of functions to parse the contrib.FeatureConfig + +Modelers can use the functions in this module as the the train/eval_parse_fn of +the DataRecordTrainer constructor to customize how to parse their datasets. + +Modelers may also provide custom implementations of train/eval_parse_fn using these as reference. 
+''' + +from twitter.deepbird.io.legacy.contrib.parsers import ( + _convert_to_fixed_length_tensor, # noqa: F401 + _get_input_receiver_fn_feature_dict, # noqa: F401 + _merge_dictionaries, # noqa: F401 + get_features_as_tensor_dict, # noqa: F401 + get_keras_parse_fn, # noqa: F401 + get_serving_input_receiver_fn_feature_dict, # noqa: F401 + get_string_tensor_parse_fn, # noqa: F401 + get_string_tensor_serving_input_receiver_fn, # noqa: F401 + get_supervised_input_receiver_fn_feature_dict, # noqa: F401 + parse_string_tensor, # noqa: F401 +) diff --git a/twml/twml/contrib/pruning.py b/twml/twml/contrib/pruning.py new file mode 100644 index 000000000..b6ddee693 --- /dev/null +++ b/twml/twml/contrib/pruning.py @@ -0,0 +1,363 @@ +""" +This module implements tools for pruning neural networks. + +In particular, it provides tools for dealing with masks: + + features = apply_mask(features) + +The function `apply_mask` applies a binary mask to the channels of a given tensor. Consider the +following loss: + + logits = tf.matmul(features, weights) + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + +Each mask has a corresponding pruning signal. The function `update_pruning_signals` will update and +return these signals: + + signals = update_pruning_signals(loss) + +The pruning operation will zero out the mask entry with the smallest corresponding pruning signal: + + prune(signals) + +The following function allows us to estimate the computational cost of a graph (number of FLOPs): + + cost = computational_cost(loss) + +To compute the cost of each feature per data point, we can do: + + costs = tf.gradients(cost / batch_size, masks) + +The current implementation of `computational_cost` is designed to work with standard feed-forward +and convolutional network architectures only, but may fail with more complicated architectures. +""" + + +import numpy as np +import tensorflow.compat.v1 as tf + +MASK_COLLECTION = 'pruning/masks' +MASK_EXTENDED_COLLECTION = 'pruning/masks_extended' +OP_COLLECTION = 'pruning/ops' + + +def apply_mask(tensor, name='pruning'): + """ + Point-wise multiplies a tensor with a binary mask. + + During training, pruning is simulated by setting entries of the mask to zero. + + Arguments: + tensor: tf.Tensor + A tensor where the last dimension represents channels which will be masked + + Returns: + `tf.Tensor` with same shape as `tensor` + """ + + tensor_shape = tensor.shape + + with tf.variable_scope(name, reuse=True): + # allocate masks and corresponding pruning signals + mask = tf.Variable(tf.ones(tensor.shape.as_list()[-1]), trainable=False, name='mask') + pruning_signal = tf.Variable(tf.zeros_like(mask), trainable=False, name='signal') + + # extending masks is a trick to get a separate gradient for each data point + mask_extended = extend_mask(mask, tensor) + + # store extended mask, pruning signal, and other vars for easy access later + mask.extended = mask_extended + mask.pruning_signal = pruning_signal + mask.tensor = tensor + + # mask tensor + tensor = tf.multiply(tensor, mask_extended) + tensor.set_shape(tensor_shape) + tensor._mask = mask + + tf.add_to_collection(MASK_COLLECTION, mask) + tf.add_to_collection(MASK_EXTENDED_COLLECTION, mask.extended) + tf.add_to_collection(OP_COLLECTION, tensor.op) + + return tensor + + +def extend_mask(mask, tensor): + """ + Repeats the mask for each data point stored in a tensor. + + If `tensor` is AxBxC dimensional and `mask` is C dimensional, returns an Ax1xC dimensional + tensor with A copies or `mask`. 
+ + Arguments: + mask: tf.Tensor + The mask which will be extended + + tensor: tf.Tensor + The tensor to which the extended mask will be applied + + Returns: + The extended mask + """ + + batch_size = tf.shape(tensor)[:1] + ones = tf.ones([tf.rank(tensor) - 1], dtype=batch_size.dtype) + multiples = tf.concat([batch_size, ones], 0) + mask_shape = tf.concat([ones, [-1]], 0) + return tf.tile(tf.reshape(mask, mask_shape), multiples) + + +def find_input_mask(tensor): + """ + Find ancestral mask affecting the number of pruned channels of a tensor. + + Arguments: + tensor: tf.Tensor + Tensor for which to identify relevant mask + + Returns: + A `tf.Tensor` or `None` + """ + + if hasattr(tensor, '_mask'): + return tensor._mask + if tensor.op.type in ['MatMul', 'Conv1D', 'Conv2D', 'Conv3D', 'Transpose']: + # op produces a new number of channels, preceding mask therefore irrelevant + return None + if not tensor.op.inputs: + return None + for input in tensor.op.inputs: + mask = find_input_mask(input) + if mask is not None: + return mask + + +def find_output_mask(tensor): + """ + Find mask applied to the tensor or one of its descendants if it affects the tensor's pruned shape. + + Arguments: + tensor: tf.Tensor or tf.Variable + Tensor for which to identify relevant mask + + Returns: + A `tf.Tensor` or `None` + """ + + if isinstance(tensor, tf.Variable): + return find_output_mask(tensor.op.outputs[0]) + if hasattr(tensor, '_mask'): + return tensor._mask + for op in tensor.consumers(): + if len(op.outputs) != 1: + continue + if op.type in ['MatMul', 'Conv1D', 'Conv2D', 'Conv3D']: + # masks of descendants are only relevant if tensor is right-multiplied + if tensor == op.inputs[1]: + return find_output_mask(op.outputs[0]) + return None + mask = find_output_mask(op.outputs[0]) + if mask is not None: + return mask + + +def find_mask(tensor): + """ + Returns masks indicating channels of the tensor that are effectively removed from the graph. + + Arguments: + tensor: tf.Tensor + Tensor for which to compute a mask + + Returns: + A `tf.Tensor` with binary entries indicating disabled channels + """ + + input_mask = find_input_mask(tensor) + output_mask = find_output_mask(tensor) + if input_mask is None: + return output_mask + if output_mask is None: + return input_mask + if input_mask is output_mask: + return input_mask + return input_mask * output_mask + + +def pruned_shape(tensor): + """ + Computes the shape of a tensor after taking into account pruning of channels. + + Note that the shape will only differ in the last dimension, even if other dimensions are also + effectively disabled by pruning masks. + + Arguments: + tensor: tf.Tensor + Tensor for which to compute a pruned shape + + Returns: + A `tf.Tensor[tf.float32]` representing the pruned shape + """ + + mask = find_mask(tensor) + + if mask is None: + return tf.cast(tf.shape(tensor), tf.float32) + + return tf.concat([ + tf.cast(tf.shape(tensor)[:-1], mask.dtype), + tf.reduce_sum(mask, keepdims=True)], 0) + + +def computational_cost(op_or_tensor, _observed=None): + """ + Estimates the computational complexity of a pruned graph (number of floating point operations). + + This function currently only supports sequential graphs such as those of MLPs and + simple CNNs with 2D convolutions in NHWC format. + + Note that the computational cost returned by this function is proportional to batch size. 
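+
+  For example, for a pruned dense layer computing tf.matmul(x, W) with x of shape
+  [B, m] and W of shape [m, n], the MatMul case below counts B * n * (2m - 1)
+  floating point operations.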
+ + Arguments: + op_or_tensor: tf.Tensor or tf.Operation + Root node of graph for which to compute computational cost + + Returns: + A `tf.Tensor` representing a number of floating point operations + """ + + cost = tf.constant(0.) + + # exclude cost of computing extended pruning masks + masks_extended = [mask.extended for mask in tf.get_collection(MASK_COLLECTION)] + if op_or_tensor in masks_extended: + return cost + + # convert tensor to op + op = op_or_tensor.op if isinstance(op_or_tensor, (tf.Tensor, tf.Variable)) else op_or_tensor + + # make sure cost of op will not be counted twice + if _observed is None: + _observed = [] + elif op in _observed: + return cost + _observed.append(op) + + # compute cost of computing inputs + for tensor in op.inputs: + cost = cost + computational_cost(tensor, _observed) + + # add cost of operation + if op.op_def is None or op in tf.get_collection(OP_COLLECTION): + # exclude cost of undefined ops and pruning ops + return cost + + elif op.op_def.name == 'MatMul': + shape_a = pruned_shape(op.inputs[0]) + shape_b = pruned_shape(op.inputs[1]) + return cost + shape_a[0] * shape_b[1] * (2. * shape_a[1] - 1.) + + elif op.op_def.name in ['Add', 'Mul', 'BiasAdd']: + return cost + tf.cond( + tf.size(op.inputs[0]) > tf.size(op.inputs[1]), + lambda: tf.reduce_prod(pruned_shape(op.inputs[0])), + lambda: tf.reduce_prod(pruned_shape(op.inputs[1]))) + + elif op.op_def.name in ['Conv2D']: + output_shape = pruned_shape(op.outputs[0]) + input_shape = pruned_shape(op.inputs[0]) + kernel_shape = pruned_shape(op.inputs[1]) + inner_prod_cost = (tf.reduce_prod(kernel_shape[:2]) * input_shape[-1] * 2. - 1.) + return cost + tf.reduce_prod(output_shape) * inner_prod_cost + + return cost + + +def update_pruning_signals(loss, decay=.96, masks=None, method='Fisher'): + """ + For each mask, computes corresponding pruning signals indicating the importance of a feature. + + Arguments: + loss: tf.Tensor + Any cross-entropy loss + + decay: float + Controls exponential moving average of pruning signals + + method: str + Method used to compute pruning signal (currently only supports 'Fisher') + + Returns: + A `list[tf.Tensor]` of pruning signals corresponding to masks + + References: + * Theis et al., Faster gaze prediction with dense networks and Fisher pruning, 2018 + """ + + if masks is None: + masks = tf.get_collection(MASK_COLLECTION) + + if method not in ['Fisher']: + raise ValueError('Pruning method \'{0}\' not supported.'.format(method)) + + if not masks: + return [] + + with tf.variable_scope('pruning_opt', reuse=True): + # compute gradients of extended masks (yields separate gradient for each data point) + grads = tf.gradients(loss, [m.extended for m in masks]) + + # estimate Fisher pruning signals from batch + signals_batch = [tf.squeeze(tf.reduce_mean(tf.square(g), 0)) for g in grads] + + # update pruning signals + signals = [m.pruning_signal for m in masks] + signals = [tf.assign(s, decay * s + (1. - decay) * f, use_locking=True) + for s, f in zip(signals, signals_batch)] + + return signals + + +def prune(signals, masks=None): + """ + Prunes a single feature by zeroing the mask entry with the smallest pruning signal. 
+ + Arguments: + signals: list[tf.Tensor] + A list of pruning signals + + masks: list[tf.Tensor] + A list of corresponding masks, defaults to `tf.get_collection(MASK_COLLECTION)` + + Returns: + A `tf.Operation` which updates masks + """ + + if masks is None: + masks = tf.get_collection(MASK_COLLECTION) + + with tf.variable_scope('pruning_opt', reuse=True): + # make sure we don't select already pruned units + signals = [tf.where(m > .5, s, tf.zeros_like(s) + np.inf) for m, s in zip(masks, signals)] + + # find units with smallest pruning signal in each layer + min_idx = [tf.argmin(s) for s in signals] + min_signals = [s[i] for s, i in zip(signals, min_idx)] + + # find layer with smallest pruning signal + l = tf.argmin(min_signals) + + # construct pruning operations, one for each mask + updates = [] + for k, i in enumerate(min_idx): + # set mask of layer l to 0 where pruning signal is smallest + updates.append( + tf.cond( + tf.equal(l, k), + lambda: tf.scatter_update( + masks[k], tf.Print(i, [i], message="Pruning layer [{0}] at index ".format(k)), 0.), + lambda: masks[k])) + + updates = tf.group(updates, name='prune') + + return updates diff --git a/twml/twml/contrib/readers/__init__.py b/twml/twml/contrib/readers/__init__.py new file mode 100644 index 000000000..e96cf0449 --- /dev/null +++ b/twml/twml/contrib/readers/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=wildcard-import +"""This module contains experimental readers classes""" +from .batch_prediction_request import BatchPredictionRequest # noqa: F401 +from .data_record import DataRecord # noqa: F401 +from .hashed_batch_prediction_request import HashedBatchPredictionRequest # noqa: F401 diff --git a/twml/twml/contrib/readers/batch_prediction_request.py b/twml/twml/contrib/readers/batch_prediction_request.py new file mode 100644 index 000000000..4408b33b4 --- /dev/null +++ b/twml/twml/contrib/readers/batch_prediction_request.py @@ -0,0 +1,8 @@ +# pylint: disable=invalid-name +""" +This module implements the reader for BatchPredictionRequest. +""" + +from twitter.deepbird.io.legacy.contrib.readers.batch_prediction_request import ( + BatchPredictionRequest # noqa: F401 +) diff --git a/twml/twml/contrib/readers/data_record.py b/twml/twml/contrib/readers/data_record.py new file mode 100644 index 000000000..ae8cc0b68 --- /dev/null +++ b/twml/twml/contrib/readers/data_record.py @@ -0,0 +1,10 @@ +""" +This module includes facilities for manipulating data records in DeepBird v2. +This contains a submodule that allows for easy feature access as Tensors. +The result of this subclass methods are dictionaries of Tensors and SparseTensors +""" + +from twitter.deepbird.io.legacy.contrib.readers.data_record import ( + SUPPORTED_DENSE_FEATURE_TYPES, # noqa: F401 + DataRecord, # noqa: F401 +) diff --git a/twml/twml/contrib/readers/hashed_batch_prediction_request.py b/twml/twml/contrib/readers/hashed_batch_prediction_request.py new file mode 100644 index 000000000..3454f8483 --- /dev/null +++ b/twml/twml/contrib/readers/hashed_batch_prediction_request.py @@ -0,0 +1,8 @@ +# pylint: disable=invalid-name +""" +This module implements the reader for HashedBatchPredictionRequest. 
+""" + +from twitter.deepbird.io.legacy.contrib.readers.hashed_batch_prediction_request import ( + HashedBatchPredictionRequest # noqa: F401 +) diff --git a/twml/twml/contrib/trainers/__init__.py b/twml/twml/contrib/trainers/__init__.py new file mode 100644 index 000000000..3226cd805 --- /dev/null +++ b/twml/twml/contrib/trainers/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=wildcard-import +"""This module contains experimental trainer classes""" +from .batch_prediction_request_trainer import BatchPredictionRequestTrainer # noqa: F401 +from .pruning_data_record_trainer import PruningDataRecordTrainer # noqa: F401 +from .trainer_utils import build_keras_trainer # noqa: F401 diff --git a/twml/twml/contrib/trainers/batch_prediction_request_trainer.py b/twml/twml/contrib/trainers/batch_prediction_request_trainer.py new file mode 100644 index 000000000..2effa87ed --- /dev/null +++ b/twml/twml/contrib/trainers/batch_prediction_request_trainer.py @@ -0,0 +1,180 @@ +# pylint: disable=arguments-differ, invalid-name +""" +This file contains the DataRecordTrainer class. +""" +import warnings + +import twml +from twml.trainers import DataRecordTrainer + + +class BatchPredictionRequestTrainer(DataRecordTrainer): # pylint: disable=abstract-method + """ + The ``BatchPredictionRequestTrainer`` implementation is intended to satisfy use cases + that input is BatchPredictionRequest at Twitter and also where only the build_graph methods + needs to be overridden. For this reason, ``Trainer.[train,eval]_input_fn`` methods + assume a DataRecord dataset partitioned into part files stored in compressed (e.g. gzip) format. + + For use-cases that differ from this common Twitter use-case, + further Trainer methods can be overridden. + If that still doesn't provide enough flexibility, the user can always + use the tf.estimator.Esimator or tf.session.run directly. + """ + + def __init__( + self, name, params, + build_graph_fn, + feature_config=None, + **kwargs): + """ + The BatchPredictionRequestTrainer constructor builds a + ``tf.estimator.Estimator`` and stores it in self.estimator. + For this reason, BatchPredictionRequestTrainer accepts the same Estimator constructor arguments. + It also accepts additional arguments to facilitate metric evaluation and multi-phase training + (init_from_dir, init_map). + + Args: + parent arguments: + See the `Trainer constructor <#twml.trainers.Trainer.__init__>`_ documentation + for a full list of arguments accepted by the parent class. + name, params, build_graph_fn (and other parent class args): + see documentation for twml.Trainer and twml.DataRecordTrainer doc. + feature_config: + An object of type FeatureConfig describing what features to decode. + Defaults to None. But it is needed in the following cases: + - `get_train_input_fn()` / `get_eval_input_fn()` is called without a `parse_fn` + - `learn()`, `train()`, `eval()`, `calibrate()` are called without providing `*input_fn`. + + **kwargs: + further kwargs can be specified and passed to the Estimator constructor. 
+ """ + + # Check and update train_batch_size and eval_batch_size in params before initialization + # to print correct parameter logs and does not stop running + # This overwrites batch_size parameter constrains in twml.trainers.Trainer.check_params + updated_params = self.check_batch_size_params(params) + super(BatchPredictionRequestTrainer, self).__init__( + name=name, params=updated_params, build_graph_fn=build_graph_fn, **kwargs) + + def check_batch_size_params(self, params): + """ Verify that params has the correct key,values """ + # updated_params is an instance of tensorflow.contrib.training.HParams + updated_params = twml.util.convert_to_hparams(params) + param_values = updated_params.values() + + # twml.trainers.Trainer.check_params already checks other constraints, + # such as being an integer + if 'train_batch_size' in param_values: + if not isinstance(updated_params.train_batch_size, int): + raise ValueError("Expecting params.train_batch_size to be an integer.") + if param_values['train_batch_size'] != 1: + # This can be a bit annoying to force users to pass the batch sizes, + # but it is good to let them know what they actually use in the models + # Use warning instead of ValueError in there to continue the run + # and print out that train_batch_size is changed + warnings.warn('You are processing BatchPredictionRequest data, ' + 'train_batch_size is always 1.\n' + 'The number of DataRecords in a batch is determined by the size ' + 'of each BatchPredictionRequest.\n' + 'If you did not pass train.batch_size or eval.batch_size, and ' + 'the default batch_size 32 was in use,\n' + 'please pass --train.batch_size 1 --eval.batch_size 1') + # If the upper error warning, change/pass --train.batch_size 1 + # so that train_batch_size = 1 + updated_params.train_batch_size = 1 + + if 'eval_batch_size' in param_values: + if not isinstance(updated_params.train_batch_size, int): + raise ValueError('Expecting params.eval_batch_size to be an integer.') + if param_values['eval_batch_size'] != 1: + # This can be a bit annoying to force users to pass the batch sizes, + # but it is good to let them know what they actually use in the models + # Use warning instead of ValueError in there to continue the run + # and print out that eval_batch_size is changed + warnings.warn('You are processing BatchPredictionRequest data, ' + 'eval_batch_size is also always 1.\n' + 'The number of DataRecords in a batch is determined by the size ' + 'of each BatchPredictionRequest.\n' + 'If you did not pass train.batch_size or eval.batch_size, and ' + 'the default batch_size 32 was in use,\n' + 'please pass --train.batch_size 1 --eval.batch_size 1') + # If the upper warning raises, change/pass --eval.batch_size 1 + # so that eval_batch_size = 1 + updated_params.eval_batch_size = 1 + + if 'eval_batch_size' not in param_values: + updated_params.eval_batch_size = 1 + + if not updated_params.eval_batch_size: + updated_params.eval_batch_size = 1 + + return updated_params + + @staticmethod + def add_batch_prediction_request_arguments(): + """ + Add commandline args to parse typically for the BatchPredictionRequestTrainer class. + Typically, the user calls this function and then parses cmd-line arguments + into an argparse.Namespace object which is then passed to the Trainer constructor + via the params argument. + + See the `code <_modules/twml/argument_parser.html#get_trainer_parser>`_ + for a list and description of all cmd-line arguments. + + Returns: + argparse.ArgumentParser instance with some useful args already added. 
+ """ + parser = super(BatchPredictionRequestTrainer, + BatchPredictionRequestTrainer).add_parser_arguments() + + # mlp arguments + parser.add_argument( + '--model.use_existing_discretizer', action='store_true', + dest="model_use_existing_discretizer", + help='Load a pre-trained calibration or train a new one') + parser.add_argument( + '--model.use_binary_values', action='store_true', + dest='model_use_binary_values', + help='Use the use_binary_values optimization') + + # control hom many featues we keep in sparse tensors + # 12 is enough for learning-to-rank for now + parser.add_argument( + '--input_size_bits', type=int, default=12, + help='Number of bits allocated to the input size') + + parser.add_argument( + '--loss_function', type=str, default='ranknet', + dest='loss_function', + help='Options are pairwise: ranknet (default), lambdarank, ' + 'listnet, listmle, attrank, ' + 'pointwise') + + # whether convert sparse tensors to dense tensor + # in order to use dense normalization methods + parser.add_argument( + '--use_dense_tensor', action='store_true', + dest='use_dense_tensor', + default=False, + help='If use_dense_tensor is False, ' + 'sparse tensor and spare normalization are in use. ' + 'If use_dense_tensor is True, ' + 'dense tensor and dense normalization are in use.') + + parser.add_argument( + '--dense_normalization', type=str, default='mean_max_normalizaiton', + dest='dense_normalization', + help='Options are mean_max_normalizaiton (default), standard_normalizaiton') + + parser.add_argument( + '--sparse_normalization', type=str, default='SparseMaxNorm', + dest='sparse_normalization', + help='Options are SparseMaxNorm (default), SparseBatchNorm') + + # so far only used in pairwise learning-to-rank + parser.add_argument( + '--mask', type=str, default='full_mask', + dest='mask', + help='Options are full_mask (default), diag_mask') + + return parser diff --git a/twml/twml/contrib/trainers/pruning_data_record_trainer.py b/twml/twml/contrib/trainers/pruning_data_record_trainer.py new file mode 100644 index 000000000..4796e5390 --- /dev/null +++ b/twml/twml/contrib/trainers/pruning_data_record_trainer.py @@ -0,0 +1,59 @@ +import tensorflow.compat.v1 as tf + +from twml.trainers import DataRecordTrainer +from twml.contrib.optimizers import PruningOptimizer + + +class PruningDataRecordTrainer(DataRecordTrainer): + @staticmethod + def get_train_op(params, loss): + train_op = DataRecordTrainer.get_train_op(params, loss) + + optimizer = PruningOptimizer(learning_rate=params.get('learning_rate')) + + return optimizer.minimize( + loss=loss, + prune_every=params.get('pruning_iter', 5000), + burn_in=params.get('pruning_burn_in', 100000), + decay=params.get('pruning_decay', .9999), + flops_target=params.get('pruning_flops_target', 250000), + update_params=train_op, + global_step=tf.train.get_global_step()) + + def __init__(self, name, params, build_graph_fn, feature_config=None, **kwargs): + kwargs['optimize_loss_fn'] = self.get_train_op + + super(PruningDataRecordTrainer, self).__init__( + name=name, + params=params, + build_graph_fn=build_graph_fn, + feature_config=feature_config, + **kwargs) + + def export_model(self, *args, **kwargs): + # TODO: modify graph before exporting to take into account masks + return super(PruningDataRecordTrainer, self).export_model(*args, **kwargs) + + @staticmethod + def add_parser_arguments(): + parser = DataRecordTrainer.add_parser_arguments() + parser.add_argument( + "--pruning.iter", "--pruning_iter", type=int, default=5000, + dest="pruning_iter", + help="A 
+    parser.add_argument(
+      "--pruning.burn_in", "--pruning_burn_in", type=int, default=100000,
+      dest="pruning_burn_in",
+      help="Only start pruning after collecting statistics for this many training steps")
+    parser.add_argument(
+      "--pruning.flops_target", "--pruning_flops_target", type=int, default=250000,
+      dest="pruning_flops_target",
+      help="Stop pruning when the estimated number of floating point operations reaches this \
+        target. For example, a small feed-forward network might require 250,000 FLOPs to run.")
+    parser.add_argument(
+      "--pruning.decay", "--pruning_decay", type=float, default=.9999,
+      dest="pruning_decay",
+      help="A float value in [0.0, 1.0) controlling an exponential moving average of pruning \
+        signal statistics. A value of 0.9999 can be thought of as averaging statistics over \
+        10,000 steps.")
+    return parser
diff --git a/twml/twml/contrib/trainers/trainer_utils.py b/twml/twml/contrib/trainers/trainer_utils.py
new file mode 100644
index 000000000..f279571be
--- /dev/null
+++ b/twml/twml/contrib/trainers/trainer_utils.py
@@ -0,0 +1,111 @@
+"""
+This is a temporary stopgap solution that allows TensorFlow users to do exploration and
+experimentation using Keras models, and production training using twml Trainer.
+
+As of now (Q4 2019), Keras model training using `model.fit()` has various issues, making it unfit
+for production training:
+  1. `model.fit()` is slow in TF 1.14. This will be fixed with future TensorFlow updates.
+  2. `model.fit()` crashes during model saving or in eager mode when the input has SparseTensor.
+  3. Models saved using the TF 2.0 API cannot be served by TensorFlow's Java API.
+
+Until the MLCE team resolves the above issues, the MLCE team recommends the following:
+  - Please feel free to use Keras models for experimentation and exploration.
+  - Please stick to twml Trainer for production training & exporting,
+    especially if you want to serve your model using Twitter's prediction servers.
+
+This module provides tooling for easily training Keras models using twml Trainer.
+
+This module takes a Keras model that performs binary classification, and returns a
+`twml.trainers.Trainer` object performing the same task.
+The common way to use the returned Trainer object is to call its
+`train`, `evaluate`, `learn`, or `train_and_evaluate` method with an input function.
+This input function can be created from the tf.data.Dataset you used with your Keras model.
+
+.. note: this util handles the most common case. If you have cases not satisfied by this util,
+  consider writing your own build_graph to wrap your keras models.
+"""
+from twitter.deepbird.hparam import HParams
+
+import tensorflow  # noqa: F401
+import tensorflow.compat.v2 as tf
+
+import twml
+
+
+def build_keras_trainer(
+    name,
+    model_factory,
+    save_dir,
+    loss_fn=None,
+    metrics_fn=None,
+    **kwargs):
+  """
+  Compile the given model_factory into a twml Trainer.
+
+  Args:
+    name: a string name for the returned twml Trainer.
+
+    model_factory: a callable that returns a keras model when called.
+      This keras model is expected to solve a binary classification problem.
+      This keras model takes a dict of tensors as input, and outputs a logit or probability.
+
+    save_dir: a directory where the trainer saves data. Can be an HDFS path.
+
+    loss_fn: the loss function to use. Defaults to tf.keras.losses.BinaryCrossentropy.
+
+    metrics_fn: metrics function used by TensorFlow estimators.
+      Defaults to twml.metrics.get_binary_class_metric_fn().
+ + **kwargs: for people familiar with twml Trainer's options, they can be passed in here + as kwargs, and they will be forwarded to Trainer as opts. + See https://cgit.twitter.biz/source/tree/twml/twml/argument_parser.py#n43 for available args. + + Returns: + a twml.trainers.Trainer object which can be used for training and exporting models. + """ + build_graph = create_build_graph_fn(model_factory, loss_fn) + + if metrics_fn is None: + metrics_fn = twml.metrics.get_binary_class_metric_fn() + + opts = HParams(**kwargs) + opts.add_hparam('save_dir', save_dir) + + return twml.trainers.Trainer( + name, + opts, + build_graph_fn=build_graph, + save_dir=save_dir, + metric_fn=metrics_fn) + + +def create_build_graph_fn(model_factory, loss_fn=None): + """Create a build graph function from the given keras model.""" + + def build_graph(features, label, mode, params, config=None): + # create model from model factory. + model = model_factory() + + # create loss function if the user didn't specify one. + if loss_fn is None: + build_graph_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False) + else: + build_graph_loss_fn = loss_fn + + output = model(features) + if mode == 'infer': + loss = None + else: + weights = features.get('weights', None) + loss = build_graph_loss_fn(y_true=label, y_pred=output, sample_weight=weights) + + if isinstance(output, dict): + if loss is None: + return output + else: + output['loss'] = loss + return output + else: + return {'output': output, 'loss': loss} + + return build_graph diff --git a/twml/twml/contrib/utils/__init__.py b/twml/twml/contrib/utils/__init__.py new file mode 100644 index 000000000..56a083048 --- /dev/null +++ b/twml/twml/contrib/utils/__init__.py @@ -0,0 +1,18 @@ +# pylint: disable=wildcard-import +"""This module contains experimental util functions for contrib.""" + +from .math_fns import safe_div, safe_log, cal_ndcg, cal_swapped_ndcg # noqa: F401 +from .masks import diag_mask, full_mask # noqa: F401 +from .normalizer import mean_max_normalizaiton, standard_normalizaiton # noqa: F401 +from .scores import get_pairwise_scores, get_pairwise_label_scores # noqa: F401 +# pointwise functions +from .loss_fns import get_pointwise_loss # noqa: F401 +# ranknet functions +from .loss_fns import get_pair_loss # noqa: F401 +# listwise functions +from .loss_fns import get_attrank_loss, get_listnet_loss, get_listmle_loss # noqa: F401 +# lambdarank functions +from .loss_fns import get_lambda_pair_loss # noqa: F401 +from .device import get_device_map, get_gpu_list, get_gpu_count, is_gpu_available # noqa: F401 +from .similarities import cosine_similarity # noqa: F401 +from . import interp # noqa: F401 diff --git a/twml/twml/contrib/utils/datasets.py b/twml/twml/contrib/utils/datasets.py new file mode 100644 index 000000000..d31ea3ae4 --- /dev/null +++ b/twml/twml/contrib/utils/datasets.py @@ -0,0 +1,93 @@ +import random + +import twml + +get_time_based_dataset_files = twml.util.list_files_by_datetime + + +def resolve_train_and_eval_files_overlap( + train_files, eval_files, fraction_kept_for_eval, seed=None +): + """Resolve any overlap between train and eval files. + + Specifically, if there's an overlap between `train_files` and `eval_files`, then a fraction of + the overlap (i.e. `fraction_kept_for_eval`) will be randomly assigned (exclusively) to the + `eval_files`. 
+
+  The following example demonstrates its usage:
+
+  >>> orig_train_files = ['f1', 'f2', 'f3', 'f4']
+  >>> orig_eval_files = ['f1', 'f2', 'f3']
+  >>> resolved_train_files, resolved_eval_files = resolve_train_and_eval_files_overlap(
+  ...   orig_train_files, orig_eval_files, 0.5
+  ... )
+  >>> set(resolved_train_files) & set(resolved_eval_files) == set()
+  True
+  >>> len(resolved_train_files) == 3
+  True
+  >>> len(resolved_eval_files) == 2
+  True
+
+  Args:
+    train_files: A list of the files used for training.
+    eval_files: A list of the files used for validation.
+    fraction_kept_for_eval: A fraction of files in the intersection between `train_files` and
+      `eval_files` exclusively kept for evaluation.
+    seed: A seed for generating random numbers.
+
+  Returns:
+    A tuple `(new_train_files, new_eval_files)` with the overlap resolved.
+  """
+
+  rng = random.Random(seed)
+
+  train_files = set(train_files)
+  eval_files = set(eval_files)
+  overlapping_files = train_files & eval_files
+  train_files_selected_for_eval = set(rng.sample(
+    overlapping_files,
+    int(len(overlapping_files) * fraction_kept_for_eval)
+  ))
+  train_files = train_files - train_files_selected_for_eval
+  eval_files = (eval_files - overlapping_files) | train_files_selected_for_eval
+  return list(train_files), list(eval_files)
+
+
+def get_time_based_dataset_files_for_train_and_eval(
+    base_path,
+    train_start_datetime,
+    train_end_datetime,
+    eval_start_datetime,
+    eval_end_datetime,
+    fraction_kept_for_eval,
+    datetime_prefix_format='%Y/%m/%d/%H',
+    extension='lzo',
+    parallelism=1
+):
+  """Get train/eval dataset files organized with a time-based prefix.
+
+  This is just a convenience built around `get_time_based_dataset_files` and
+  `resolve_train_and_eval_files_overlap`. Please refer to these functions for documentation.
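+
+  A minimal usage sketch (the base path and datetime values below are hypothetical, and
+  the datetime arguments are assumed to follow `twml.util.list_files_by_datetime`):
+
+  .. code-block:: python
+
+    from datetime import datetime
+
+    train_files, eval_files = get_time_based_dataset_files_for_train_and_eval(
+      base_path='hdfs://example/path/to/dataset',
+      train_start_datetime=datetime(2019, 1, 1),
+      train_end_datetime=datetime(2019, 1, 7),
+      eval_start_datetime=datetime(2019, 1, 7),
+      eval_end_datetime=datetime(2019, 1, 8),
+      fraction_kept_for_eval=1.0)  # overlapping files go exclusively to eval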
+ """ + + train_files = get_time_based_dataset_files( + base_path=base_path, + start_datetime=train_start_datetime, + end_datetime=train_end_datetime, + datetime_prefix_format=datetime_prefix_format, + extension=extension, + parallelism=parallelism + ) + eval_files = get_time_based_dataset_files( + base_path=base_path, + start_datetime=eval_start_datetime, + end_datetime=eval_end_datetime, + datetime_prefix_format=datetime_prefix_format, + extension=extension, + parallelism=parallelism + ) + return resolve_train_and_eval_files_overlap( + train_files=train_files, + eval_files=eval_files, + fraction_kept_for_eval=fraction_kept_for_eval + ) diff --git a/twml/twml/contrib/utils/device.py b/twml/twml/contrib/utils/device.py new file mode 100644 index 000000000..ab189c98a --- /dev/null +++ b/twml/twml/contrib/utils/device.py @@ -0,0 +1,27 @@ +""" +Functions to query devices being used by tensorflow +""" + +from tensorflow.python.client import device_lib + + +def get_device_map(): + """Returns the map of device name to device type""" + local_device_protos = device_lib.list_local_devices() + return {x.name: x.device_type for x in local_device_protos} + + +def get_gpu_list(): + """Returns the list of GPUs available""" + device_map = get_device_map() + return [name for name in device_map if device_map[name] == 'GPU'] + + +def get_gpu_count(): + """Returns the count of GPUs available""" + return len(get_gpu_list()) + + +def is_gpu_available(): + """Returns if GPUs are available""" + return get_gpu_count() > 0 diff --git a/twml/twml/contrib/utils/interp.py b/twml/twml/contrib/utils/interp.py new file mode 100644 index 000000000..419d89030 --- /dev/null +++ b/twml/twml/contrib/utils/interp.py @@ -0,0 +1,94 @@ +""" +Interpolation functions +""" + +import libtwml +import tensorflow.compat.v1 as tf +import twml + + +def linear_interp1(inputs, ref_inputs, ref_outputs): + """ + Perform 1D linear interpolation. + Arguments: + inputs: + The query input values. + ref_inputs: + Reference grid points used for interpolation. + ref_outputs: + Reference output values used for interpolation. + + Returns: + The interpolated outputs for the requested input values. + """ + + inputs = tf.convert_to_tensor(inputs) + ref_inputs = tf.convert_to_tensor(ref_inputs) + ref_outputs = tf.convert_to_tensor(ref_outputs) + + ndims = inputs.shape.ndims + ref_inputs_ndims = ref_inputs.shape.ndims + ref_outputs_ndims = ref_inputs.shape.ndims + + if (ref_inputs_ndims != ndims): + raise ValueError("Dimension mismatch. inputs: %d, ref_inputs: %d" % (ndims, ref_inputs_ndims)) + + if (ref_outputs_ndims != ndims): + raise ValueError("Dimension mismatch. inputs: %d, ref_outputs: %d" % (ndims, ref_outputs_ndims)) + + if ndims > 2: + raise ValueError("Input dimensions should be < 2D. But got %d." % ndims) + + original_input_shape = tf.shape(inputs) + # This is needed because isotonic_calibration expects: + # - inputs of size [num_samples, num_classes] + # - ref_inputs, ref_outputs of size [num_classes, num_bins] + inputs = tf.reshape(inputs, [-1, 1]) + ref_inputs = tf.reshape(ref_inputs, [1, -1]) + ref_outputs = tf.reshape(ref_outputs, [1, -1]) + + # isotonic_calibration is simply doing linear interpolation. + # This needs to be renamed in the future to make it consistent. + outputs = libtwml.ops.isotonic_calibration(inputs, ref_inputs, ref_outputs) + return tf.reshape(outputs, original_input_shape) + + +def linear_interp1_by_class(inputs, input_classes, ref_inputs, ref_outputs): + """ + Perform 1D linear interpolation. 
+  Arguments:
+    inputs:
+      The query input values.
+    input_classes:
+      The class index to use from the reference grid.
+    ref_inputs:
+      Reference 2D grid points used for interpolation.
+      Each row denotes the grid from a different class.
+    ref_outputs:
+      Reference 2D output values used for interpolation.
+      Each row denotes the grid from a different class.
+
+  Returns:
+    The interpolated outputs for the requested input values.
+  """
+
+  inputs = tf.convert_to_tensor(inputs)
+  input_classes = tf.convert_to_tensor(input_classes)
+  ref_inputs = tf.convert_to_tensor(ref_inputs)
+  ref_outputs = tf.convert_to_tensor(ref_outputs)
+
+  original_input_shape = tf.shape(inputs)
+
+  # pass through
+  def in_func(x):
+    return x
+
+  # indexed function
+  def cond_func(i, fn):
+    idx = input_classes[i]
+    x = tf.expand_dims(fn(), axis=0)
+    return linear_interp1(x, ref_inputs[idx], ref_outputs[idx])
+
+  # Use a while loop for now; this needs to be replaced by a custom C++ op later.
+  outputs = twml.util.batch_apply(in_func, inputs, cond_func=cond_func)
+  return tf.reshape(outputs, original_input_shape)
diff --git a/twml/twml/contrib/utils/loss_fns.py b/twml/twml/contrib/utils/loss_fns.py
new file mode 100644
index 000000000..eb25b430a
--- /dev/null
+++ b/twml/twml/contrib/utils/loss_fns.py
@@ -0,0 +1,302 @@
+import tensorflow.compat.v1 as tf
+from twml.contrib.utils import masks, math_fns
+
+
+def get_pair_loss(pairwise_label_scores, pairwise_predicted_scores,
+                  params):
+  """
+  Pairwise learning-to-rank RankNet loss
+  Check paper https://www.microsoft.com/en-us/research/publication/
+  learning-to-rank-using-gradient-descent/
+  for more information
+  Args:
+    pairwise_label_scores: a dense tensor of shape [n_data, n_data]
+    pairwise_predicted_scores: a dense tensor of shape [n_data, n_data]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    params: network parameters
+    mask options: full_mask and diag_mask
+  Returns:
+    average loss over pairs defined by the masks
+  """
+  n_data = tf.shape(pairwise_label_scores)[0]
+  if params.mask == "full_mask":
+    # full_mask only covers pairs that have different labels
+    # (all pairwise_label_scores = 0.5: selfs and same labels are 0s)
+    mask, pair_count = masks.full_mask(n_data, pairwise_label_scores)
+  else:
+    # diag_mask covers all pairs
+    # (only selfs/diags are 0s)
+    mask, pair_count = masks.diag_mask(n_data, pairwise_label_scores)
+
+  # pairwise sigmoid_cross_entropy_with_logits loss
+  loss = tf.cond(tf.equal(pair_count, 0), lambda: 0.,
+                 lambda: _get_average_cross_entropy_loss(pairwise_label_scores,
+                                                         pairwise_predicted_scores,
+                                                         mask, pair_count))
+  return loss
+
+
+def get_lambda_pair_loss(pairwise_label_scores, pairwise_predicted_scores,
+                         params, swapped_ndcg):
+  """
+  Pairwise learning-to-rank LambdaRank loss,
+  faster than the previous gradient method.
+  Note: this loss depends on the RankNet cross-entropy;
+  delta NDCG is applied to the RankNet cross-entropy.
+  Hence, it is still a gradient descent method.
+  Check paper http://citeseerx.ist.psu.edu/viewdoc/
+  download?doi=10.1.1.180.634&rep=rep1&type=pdf
+  for more information
+  Args:
+    pairwise_label_scores: a dense tensor of shape [n_data, n_data]
+    pairwise_predicted_scores: a dense tensor of shape [n_data, n_data]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    params: network parameters
+    swapped_ndcg: swapped ndcg of shape [n_data, n_data]
+      ndcg values when swapping each pair in the prediction ranking order
+    mask options: full_mask and diag_mask
+  Returns:
+    average loss over pairs defined by the masks
+  """
+  n_data = tf.shape(pairwise_label_scores)[0]
+  if params.mask == "full_mask":
+    # full_mask only covers pairs that have different labels
+    # (all pairwise_label_scores = 0.5: selfs and same labels are 0s)
+    mask, pair_count = masks.full_mask(n_data, pairwise_label_scores)
+  else:
+    # diag_mask covers all pairs
+    # (only selfs/diags are 0s)
+    mask, pair_count = masks.diag_mask(n_data, pairwise_label_scores)
+
+  # pairwise sigmoid_cross_entropy_with_logits loss
+  loss = tf.cond(tf.equal(pair_count, 0), lambda: 0.,
+                 lambda: _get_average_cross_entropy_loss(pairwise_label_scores,
+                                                         pairwise_predicted_scores,
+                                                         mask, pair_count, swapped_ndcg))
+  return loss
+
+
+def _get_average_cross_entropy_loss(pairwise_label_scores, pairwise_predicted_scores,
+                                    mask, pair_count, swapped_ndcg=None):
+  """
+  Average the loss for a BatchPredictionRequest based on a desired number of pairs
+  """
+  loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=pairwise_label_scores,
+                                                 logits=pairwise_predicted_scores)
+  loss = mask * loss
+  if swapped_ndcg is not None:
+    loss = loss * swapped_ndcg
+  loss = tf.reduce_sum(loss) / pair_count
+  return loss
+
+
+def get_listmle_loss(labels, predicted_scores):
+  r"""
+  Listwise learning-to-rank listMLE loss
+  Note: a simplified MLE formula is used here (proof omitted)
+  \sum_{s=1}^{n-1} (-predicted_scores + ln(\sum_{i=s}^n exp(predicted_scores)))
+  n is tf.shape(predicted_scores)[0]
+  Check paper http://icml2008.cs.helsinki.fi/papers/167.pdf for more information
+  Args:
+    labels: a dense tensor of shape [n_data, 1]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    predicted_scores: a dense tensor of same shape and type as labels
+  Returns:
+    average loss
+  """
+  labels = tf.reshape(labels, [-1, 1])
+  n_data = tf.shape(labels)[0]
+  predicted_scores = tf.reshape(predicted_scores, [-1, 1])
+
+  predicted_scores_ordered_by_labels = _get_ordered_predicted_scores(labels,
+                                                                     predicted_scores, n_data)
+
+  loss = (-1) * tf.reduce_sum(predicted_scores)
+  # sum over 1 to n_data - 1
+  temp = tf.gather(predicted_scores_ordered_by_labels, [n_data - 1])
+  temp = tf.reshape(temp, [])
+  loss = tf.add(loss, temp)
+
+  exps = tf.exp(predicted_scores_ordered_by_labels)
+  exp_sum = tf.reduce_sum(exps)
+  # clip exp_sum for safer log
+  loss = tf.add(loss, math_fns.safe_log(exp_sum))
+
+  iteration = tf.constant(0)
+
+  def _cond(iteration, loss, exp_sum, exps):
+    return tf.less(iteration, n_data - 2)
+
+  def _gen_loop_body():
+    def loop_body(iteration, loss, exp_sum, exps):
+      temp = tf.gather(exps, [iteration])
+      temp = tf.reshape(temp, [])
+      exp_sum = tf.subtract(exp_sum, temp)
+      # clip exp_sum for safer log
+      loss = tf.add(loss, math_fns.safe_log(exp_sum))
+      return tf.add(iteration, 1), loss, exp_sum, exps
+    return loop_body
+
+  iteration, loss, exp_sum, exps = tf.while_loop(_cond, _gen_loop_body(),
+                                                 (iteration, loss, exp_sum, exps))
+  loss = loss / tf.cast(n_data, dtype=tf.float32)
+  return loss
+
+
+def _get_ordered_predicted_scores(labels, predicted_scores, n_data):
+  """
+  Order predicted_scores based on sorted labels
+  """
+  sorted_labels, ordered_labels_indices = tf.nn.top_k(
+    tf.transpose(labels), k=n_data)
+  ordered_labels_indices = tf.transpose(ordered_labels_indices)
+  predicted_scores_ordered_by_labels = tf.gather_nd(predicted_scores,
+                                                    ordered_labels_indices)
+  return predicted_scores_ordered_by_labels
+
+
+def get_attrank_loss(labels, predicted_scores, weights=None):
+  """
+  Modified listwise learning-to-rank AttRank loss
+  Check paper https://arxiv.org/abs/1804.05936 for more information
+  Note: there is an inconsistency between the paper statement and
+  their public code
+  Args:
+    labels: a dense tensor of shape [n_data, 1]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    predicted_scores: a dense tensor of same shape and type as labels
+    weights: a dense tensor of the same shape as labels
+  Returns:
+    average loss
+  """
+  # The authors implemented the following, which is basically listnet
+  # attention_labels = _get_attentions(labels)
+  # attention_labels = tf.reshape(attention_labels, [1, -1])
+  # predicted_scores = tf.reshape(predicted_scores, [1, -1])
+  # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=attention_labels,
+  #   logits=predicted_scores))
+
+  # The paper proposed the following
+  # attention_labels = _get_attentions(labels)
+  # # However the following line is wrong based on their statement
+  # # as _get_attentions can give 0 results when input < 0
+  # # and the result cannot be used in _get_attrank_cross_entropy
+  # # log(a_i^S)
+  # # attention_predicted_scores = _get_attentions(predicted_scores)
+  # loss = _get_attrank_cross_entropy(attention_labels, attention_predicted_scores)
+  # # the range of attention_predicted_scores is [0, 1)
+  # # this gives sigmoid [0.5, 0.732)
+  # # hence, it is not good to use in sigmoid_cross_entropy_with_logits either
+
+  # Implemented the following instead:
+  # _get_attentions is applied to labels
+  # softmax is applied to predicted_scores
+  reshaped_labels = tf.reshape(labels, [1, -1])
+  attention_labels = _get_attentions(reshaped_labels)
+  reshaped_predicted_scores = tf.reshape(predicted_scores, [1, -1])
+  attention_predicted_scores = tf.nn.softmax(reshaped_predicted_scores)
+  loss = _get_attrank_cross_entropy(attention_labels, attention_predicted_scores)
+  return loss
+
+
+def _get_attentions(raw_scores):
+  """
+  Used in attention weights in AttRank loss
+  for a query/batch/BatchPredictionRequest
+  (a rectified softmax function)
+  """
+  not_consider = tf.less_equal(raw_scores, 0)
+  mask = tf.ones(tf.shape(raw_scores)) - tf.cast(not_consider, dtype=tf.float32)
+  mask = tf.cast(mask, dtype=tf.float32)
+  expon_labels = mask * tf.exp(raw_scores)
+
+  expon_label_sum = tf.reduce_sum(expon_labels)
+  # expon_label_sum is safe as a denominator
+  attentions = math_fns.safe_div(expon_labels, expon_label_sum)
+  return attentions
+
+
+def _get_attrank_cross_entropy(labels, logits):
+  # logits is not safe based on their statement
+  # do not use this function directly elsewhere
+  results = labels * math_fns.safe_log(logits) + (1 - labels) * math_fns.safe_log(1 - logits)
+  results = (-1) * results
+  results = tf.reduce_mean(results)
+  return results
+
+
+def get_listnet_loss(labels, predicted_scores, weights=None):
+  """
+  Listwise learning-to-rank listnet loss
+  Check paper https://www.microsoft.com/en-us/research/
+  wp-content/uploads/2016/02/tr-2007-40.pdf
+  for more information
+  Args:
+    labels: a dense tensor of shape [n_data, 1]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    predicted_scores: a dense tensor of same shape and type as labels
+    weights: a dense tensor of the same shape as labels
+  Returns:
+    average loss
+  """
+  # top one probability is the same as softmax
+  labels_top_one_probs = _get_top_one_probs(labels)
+  predicted_scores_top_one_probs = _get_top_one_probs(predicted_scores)
+
+  if weights is None:
+    loss = tf.reduce_mean(
+      _get_listnet_cross_entropy(labels=labels_top_one_probs,
+                                 logits=predicted_scores_top_one_probs))
+    return loss
+
+  loss = tf.reduce_mean(
+    _get_listnet_cross_entropy(labels=labels_top_one_probs,
+                               logits=predicted_scores_top_one_probs) * weights) / tf.reduce_mean(weights)
+  return loss
+
+
+def _get_top_one_probs(labels):
+  """
+  Used in listnet top-one probabilities
+  for a query/batch/BatchPredictionRequest
+  (essentially a softmax function)
+  """
+  expon_labels = tf.exp(labels)
+  expon_label_sum = tf.reduce_sum(expon_labels)
+  # expon_label_sum is safe as a denominator
+  attentions = expon_labels / expon_label_sum
+  return attentions
+
+
+def _get_listnet_cross_entropy(labels, logits):
+  """
+  Used in listnet:
+  cross entropy on top-one probabilities
+  between ideal/label top-one probabilities
+  and predicted/logits top-one probabilities
+  for a query/batch/BatchPredictionRequest
+  """
+  # it is safe to use log on logits
+  # that come from _get_top_one_probs
+  # do not use this function directly elsewhere
+  results = (-1) * labels * math_fns.safe_log(logits)
+  return results
+
+
+def get_pointwise_loss(labels, predicted_scores, weights=None):
+  """
+  Pointwise learning-to-rank loss
+  Args:
+    labels: a dense tensor of shape [n_data, 1]
+    n_data is the number of tweet candidates in a BatchPredictionRequest
+    predicted_scores: a dense tensor of same shape and type as labels
+    weights: a dense tensor of the same shape as labels
+  Returns:
+    average loss
+  """
+  if weights is None:
+    loss = tf.reduce_mean(
+      tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
+                                              logits=predicted_scores))
+    return loss
+  loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
+                                                                logits=predicted_scores) * weights) / tf.reduce_mean(weights)
+  return loss
diff --git a/twml/twml/contrib/utils/masks.py b/twml/twml/contrib/utils/masks.py
new file mode 100644
index 000000000..f3143dc52
--- /dev/null
+++ b/twml/twml/contrib/utils/masks.py
@@ -0,0 +1,38 @@
+import tensorflow.compat.v1 as tf
+
+
+def diag_mask(n_data, pairwise_label_scores):
+  """
+  This is so far only used in pairwise learning-to-rank
+  Args:
+    n_data: an int `Tensor`.
+    pairwise_label_scores: a dense `Tensor` of shape [n_data, n_data].
+  Returns:
+    values in pairwise_label_scores except the diagonal;
+    each cell contains a pairwise score difference
+    only selfs/diags are 0s
+  """
+  mask = tf.ones([n_data, n_data]) - tf.diag(tf.ones([n_data]))
+  mask = tf.cast(mask, dtype=tf.float32)
+  pair_count = tf.to_float(n_data) * (tf.to_float(n_data) - 1)
+  pair_count = tf.cast(pair_count, dtype=tf.float32)
+  return mask, pair_count
+
+
+def full_mask(n_data, pairwise_label_scores):
+  """
+  This is so far only used in pairwise learning-to-rank
+  Args:
+    n_data: an int `Tensor`.
+    pairwise_label_scores: a dense `Tensor` of shape [n_data, n_data].
+  Returns:
+    values in pairwise_label_scores except pairs that have the same labels;
+    each cell contains a pairwise score difference
+    all pairwise_label_scores = 0.5: selfs and same labels are 0s
+  """
+  not_consider = tf.equal(pairwise_label_scores, 0.5)
+  mask = tf.ones([n_data, n_data]) - tf.cast(not_consider, dtype=tf.float32)
+  mask = tf.cast(mask, dtype=tf.float32)
+  pair_count = tf.reduce_sum(mask)
+  pair_count = tf.cast(pair_count, dtype=tf.float32)
+  return mask, pair_count
diff --git a/twml/twml/contrib/utils/math_fns.py b/twml/twml/contrib/utils/math_fns.py
new file mode 100644
index 000000000..2d9e72282
--- /dev/null
+++ b/twml/twml/contrib/utils/math_fns.py
@@ -0,0 +1,171 @@
+import tensorflow.compat.v1 as tf
+from tensorflow.python.ops import array_ops, math_ops
+
+
+# Copied from metrics_impl.py
+# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/metrics_impl.py#L216
+def safe_div(numerator, denominator, name=None):
+  """
+  Divides two tensors element-wise, returning 0 if the denominator is <= 0.
+  Example usage: calculating NDCG = DCG / IDCG; when IDCG = 0 this
+  returns 0 instead of Infinity.
+  Do not use this dividing function unless it makes sense for your problem.
+  Args:
+    numerator: a real `Tensor`.
+    denominator: a real `Tensor`, with dtype matching `numerator`.
+    name: Name for the returned op.
+  Returns:
+    0 if `denominator` <= 0, else `numerator` / `denominator`
+  """
+  t = math_ops.truediv(numerator, denominator)
+  zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+  condition = math_ops.greater(denominator, zero)
+  zero = math_ops.cast(zero, t.dtype)
+  return array_ops.where(condition, t, zero, name=name)
+
+
+def cal_ndcg(label_scores, predicted_scores, top_k_int=1):
+  """
+  Calculate NDCG score for top_k_int ranking positions
+  Args:
+    label_scores: a real `Tensor`.
+    predicted_scores: a real `Tensor`, with dtype matching label_scores
+    top_k_int: An int or an int `Tensor`.
+  Returns:
+    a `Tensor` that holds DCG / IDCG.
+  """
+  sorted_labels, predicted_order = _get_ranking_orders(
+    label_scores, predicted_scores, top_k_int=top_k_int)
+
+  predicted_relevance = _get_relevance_scores(predicted_order)
+  sorted_relevance = _get_relevance_scores(sorted_labels)
+
+  cg_discount = _get_cg_discount(top_k_int)
+
+  dcg = _dcg_idcg(predicted_relevance, cg_discount)
+  idcg = _dcg_idcg(sorted_relevance, cg_discount)
+  # the ndcg score of the batch
+  # idcg is 0 if label_scores are all 0
+  ndcg = safe_div(dcg, idcg, 'one_ndcg')
+  return ndcg
+
+
+def cal_swapped_ndcg(label_scores, predicted_scores, top_k_int):
+  """
+  Calculate swapped NDCG score in LambdaRank for full/top k ranking positions
+  Args:
+    label_scores: a real `Tensor`.
+    predicted_scores: a real `Tensor`, with dtype matching label_scores
+    top_k_int: An int or an int `Tensor`.
+  Returns:
+    a `Tensor` that holds the swapped NDCG.
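+
+  Note: for each pair (i, j), the swapped NDCG computed below is
+  |NDCG - NDCG_swap(i, j)|, where NDCG_swap(i, j) is the NDCG obtained after
+  swapping positions i and j in the predicted ranking; this is the |delta NDCG|
+  factor used by LambdaRank.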
+ """ + sorted_labels, predicted_order = _get_ranking_orders( + label_scores, predicted_scores, top_k_int=top_k_int) + + predicted_relevance = _get_relevance_scores(predicted_order) + sorted_relevance = _get_relevance_scores(sorted_labels) + + cg_discount = _get_cg_discount(top_k_int) + + # cg_discount is safe as a denominator + dcg_k = predicted_relevance / cg_discount + dcg = tf.reduce_sum(dcg_k) + + idcg_k = sorted_relevance / cg_discount + idcg = tf.reduce_sum(idcg_k) + + ndcg = safe_div(dcg, idcg, 'ndcg_in_lambdarank_training') + + # remove the gain from label i then add the gain from label j + tiled_ij = tf.tile(dcg_k, [1, top_k_int]) + new_ij = (predicted_relevance / tf.transpose(cg_discount)) + + tiled_ji = tf.tile(tf.transpose(dcg_k), [top_k_int, 1]) + new_ji = tf.transpose(predicted_relevance) / cg_discount + + # if swap i and j, remove the stale cg for i, then add the new cg for i, + # remove the stale cg for j, and then add the new cg for j + new_dcg = dcg - tiled_ij + new_ij - tiled_ji + new_ji + + new_ndcg = safe_div(new_dcg, idcg, 'new_ndcg_in_lambdarank_training') + swapped_ndcg = tf.abs(ndcg - new_ndcg) + return swapped_ndcg + + +def _dcg_idcg(relevance_scores, cg_discount): + """ + Calculate DCG scores for top_k_int ranking positions + Args: + relevance_scores: a real `Tensor`. + cg_discount: a real `Tensor`, with dtype matching relevance_scores + Returns: + a `Tensor` that holds \\sum_{i=1}^k \frac{relevance_scores_k}{cg_discount} + """ + # cg_discount is safe + dcg_k = relevance_scores / cg_discount + return tf.reduce_sum(dcg_k) + + +def _get_ranking_orders(label_scores, predicted_scores, top_k_int=1): + """ + Calculate DCG scores for top_k_int ranking positions + Args: + label_scores: a real `Tensor`. + predicted_scores: a real `Tensor`, with dtype matching label_scores + top_k_int: an integer or an int `Tensor`. + Returns: + two `Tensors` that hold sorted_labels: the ground truth relevance socres + and predicted_order: relevance socres based on sorted predicted_scores + """ + # sort predictions_scores and label_scores + # size [batch_size/num of DataRecords, 1] + label_scores = tf.reshape(label_scores, [-1, 1]) + predicted_scores = tf.reshape(predicted_scores, [-1, 1]) + # sorted_labels contians the relevance scores of the correct order + sorted_labels, ordered_labels_indices = tf.nn.top_k( + tf.transpose(label_scores), k=top_k_int) + sorted_labels = tf.transpose(sorted_labels) + # sort predicitons and use the indices to obtain the relevance scores of the predicted order + sorted_predictions, ordered_predictions_indices = tf.nn.top_k( + tf.transpose(predicted_scores), k=top_k_int) + ordered_predictions_indices_for_labels = tf.transpose(ordered_predictions_indices) + # predicted_order contians the relevance scores of the predicted order + predicted_order = tf.gather_nd(label_scores, ordered_predictions_indices_for_labels) + return sorted_labels, predicted_order + + +def _get_cg_discount(top_k_int=1): + r""" + Calculate discounted gain factor for ranking position till top_k_int + Args: + top_k_int: An int or an int `Tensor`. 
+  Returns:
+    a `Tensor` that holds \log_{2}(i + 1), i \in [1, k]
+  """
+  log_2 = tf.log(tf.constant(2.0, dtype=tf.float32))
+  # top_k_range needs to range from 1 to top_k_int
+  top_k_range = tf.range(top_k_int) + 1
+  top_k_range = tf.reshape(top_k_range, [-1, 1])
+  # cast top_k_range to float
+  top_k_range = tf.cast(top_k_range, dtype=tf.float32)
+  cg_discount = tf.log(top_k_range + 1.0) / log_2
+  return cg_discount
+
+
+def _get_relevance_scores(scores):
+  return 2 ** scores - 1
+
+
+def safe_log(raw_scores, name=None):
+  """
+  Calculate the log of a tensor, handling cases where
+  raw_scores are close to 0
+  Args:
+    raw_scores: A float `Tensor`.
+  Returns:
+    A float `Tensor` that holds the safe natural log (base e) of the input
+  """
+  epsilon = 1E-8
+  clipped_raw_scores = tf.maximum(raw_scores, epsilon)
+  return tf.log(clipped_raw_scores)
diff --git a/twml/twml/contrib/utils/normalizer.py b/twml/twml/contrib/utils/normalizer.py
new file mode 100644
index 000000000..a6a7035b8
--- /dev/null
+++ b/twml/twml/contrib/utils/normalizer.py
@@ -0,0 +1,39 @@
+import tensorflow.compat.v1 as tf
+from twml.contrib.utils import math_fns
+
+
+def mean_max_normalizaiton(dense_tensor):
+  """
+  In-batch normalization
+  Args:
+    dense_tensor: A dense `Tensor`.
+  Returns:
+    (dense_tensor - mean) / abs(max value)
+  Note:
+    when dense_tensor has shape [1, ?] the result is all zeros.
+    If this is not what you want, handle it outside this function.
+  """
+  dense_mean = tf.reduce_mean(dense_tensor, reduction_indices=[0])
+  dense_abs_max = tf.abs(tf.reduce_max(dense_tensor, reduction_indices=[0]))
+  dense_tensor = math_fns.safe_div(dense_tensor - dense_mean, dense_abs_max,
+                                   'mean_max_normalization_in_batch')
+  return dense_tensor
+
+
+def standard_normalizaiton(dense_tensor):
+  """
+  In-batch normalization:
+  z-normalization or standard normalization in batch
+  Args:
+    dense_tensor: A dense `Tensor`.
+  Returns:
+    (dense_tensor - mean) / variance
+  Note:
+    when dense_tensor has shape [1, ?] the result is all zeros.
+    If this is not what you want, handle it outside this function.
+  """
+  epsilon = 1E-7
+  dense_mean, dense_variance = tf.nn.moments(dense_tensor, 0)
+  # using epsilon is safer than math_fns.safe_div here
+  dense_tensor = (dense_tensor - dense_mean) / (dense_variance + epsilon)
+  return dense_tensor
diff --git a/twml/twml/contrib/utils/scores.py b/twml/twml/contrib/utils/scores.py
new file mode 100644
index 000000000..84e792c13
--- /dev/null
+++ b/twml/twml/contrib/utils/scores.py
@@ -0,0 +1,33 @@
+import tensorflow.compat.v1 as tf
+
+
+def get_pairwise_scores(tensor_input):
+  """
+  This is so far used in pairwise learning-to-rank
+
+  Arguments:
+    tensor_input: a dense `Tensor` of shape [n_data, 1]
+    n_data is the number of tweet candidates
+
+  Returns:
+    pairwise scores: a dense `Tensor` of shape [n_data, n_data].
+  """
+  return tensor_input - tf.transpose(tensor_input)
+
+
+def get_pairwise_label_scores(labels):
+  """
+  This is so far used in pairwise learning-to-rank
+  Args:
+    labels: a dense `Tensor` of shape [n_data, 1]
+    n_data is the number of tweet candidates
+  Returns:
+    pairwise label scores: a dense `Tensor` of shape [n_data, n_data].
+    each value is within [0, 1]
+  """
+  # raw pairwise label scores/differences
+  pairwise_label_scores = get_pairwise_scores(labels)
+  # sanity check to make sure values in differences_ij are within [-1, 1]
+  differences_ij = tf.maximum(tf.minimum(1.0, pairwise_label_scores), -1.0)
+  # values in pairwise_label_scores are within [0, 1] for cross entropy
+  return (1.0 / 2.0) * (1.0 + differences_ij)
diff --git a/twml/twml/contrib/utils/similarities.py b/twml/twml/contrib/utils/similarities.py
new file mode 100644
index 000000000..212065f88
--- /dev/null
+++ b/twml/twml/contrib/utils/similarities.py
@@ -0,0 +1,17 @@
+import tensorflow.compat.v1 as tf
+
+
+def cosine_similarity(x1, x2, axis):
+  """
+  Cosine similarity of two tensors.
+
+  Arguments:
+    x1:
+      A tf.Tensor
+    x2:
+      A tf.Tensor
+    axis: Dimension along which to normalize.
+  """
+  normalize_x1 = tf.nn.l2_normalize(x1, axis=axis)
+  normalize_x2 = tf.nn.l2_normalize(x2, axis=axis)
+  return tf.reduce_sum(tf.multiply(normalize_x1, normalize_x2), axis=axis)
diff --git a/twml/twml/dataset.py b/twml/twml/dataset.py
new file mode 100644
index 000000000..4356fdc7c
--- /dev/null
+++ b/twml/twml/dataset.py
@@ -0,0 +1,372 @@
+"""
+This module implements custom tf.data.datasets for twml.
+"""
+import numbers
+
+from absl import logging
+from kazoo.client import KazooClient
+from libtwml import OPLIB
+import tensorflow.compat.v1 as tf
+from twml.constants import DEFAULT_ZOOKEEPER_BASE_ZNODE, DEFAULT_ZOOKEEPER_HOST
+
+
+class BlockFormatDataset(tf.data.Dataset):
+  """A ``tf.data.Dataset`` comprising records from one or more block format files."""
+
+  def __init__(self, filenames, compression_type="auto", buffer_size=1 << 20):
+    """
+    Creates a ``BlockFormatDataset``.
+
+    Args:
+      filenames:
+        A `tf.string` tensor containing one or more filenames.
+      compression_type:
+        A string specifying the compression type.
+        Can be one of 'gz' (or 'gzip'), 'none', 'auto' (default).
+        When compression_type == 'auto', it is inferred from the file extension.
+      buffer_size:
+        Buffer size to be used during decompression. default: 1<<20.
+    """
+    self._filenames = tf.convert_to_tensor(filenames, dtype=tf.string, name="filenames")
+    self._compression_type = tf.convert_to_tensor(compression_type.lower(), name="compression_type")
+    self._buffer_size = tf.convert_to_tensor(buffer_size, dtype=tf.int64, name="buffer_size")
+    # The parent class calls self._as_variant_tensor in its __init__, so call it at the end.
+    super(BlockFormatDataset, self).__init__()
+
+  def _as_variant_tensor(self):
+    """
+    Create the resource handle for the dataset.
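+
+    The resource handle is the variant tensor that tf.data uses internally to
+    represent this dataset; it is produced by the internal kernel when available,
+    with a fallback to the open-source OPLIB kernel (see the try/except below).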
+ """ + try: + block_format_dataset = __import__("libtwml_internal").OPLIB.block_format_dataset + return block_format_dataset(self._filenames) + except ImportError: + block_format_dataset = OPLIB.block_format_dataset_v2 + return block_format_dataset(self._filenames, self._compression_type, self._buffer_size) + + def _inputs(self): + return [] + + @property + def output_shapes(self): + """Return output shapes""" + return tf.TensorShape([]) + + @property + def output_types(self): + """Return output types""" + return tf.string + + @property + def output_classes(self): + """Return output classes""" + return tf.Tensor + + +def downsample_dataset(dataset, sample_rate, rate_name): + """ + Downsample a tf.data.Dataset at sample_rate + """ + if sample_rate is None or sample_rate == 1.0: + return dataset + elif not isinstance(sample_rate, numbers.Real): + raise TypeError("dataset %s must be a real number" % rate_name) + elif sample_rate <= 0 or sample_rate > 1: + raise ValueError("dataset %s must be in range (0, 1])" % rate_name) + return dataset.filter(lambda _: tf.squeeze(tf.random_uniform([1])) < sample_rate) + + +def _filenames_dataset(files, shards=None, shard_index=None): + """ + Get a tf.data.Dataset with file names from a list of files + Optionally shard the file list (see stream_block_format_dataset) + """ + files = tf.data.Dataset.from_tensor_slices(files) + + if [shards, shard_index] != [None, None]: + logging.info("Sharding files dataset (index: %d, shards: %d)" % (shard_index, shards)) + files = files.shard(num_shards=shards, index=shard_index) + + return files + + +def stream_block_format_dataset( + files, parse_fn, batch_size, num_threads, + shuffle=True, repeat=False, + block_length=None, part_file_parallelism=None, file_shuffle_size=None, + record_shuffle_size=None, dataset_fn=None, + keep_rate=None, parts_downsampling_rate=None, prefetch_size=2, + shards=None, shard_index=None, shuffle_files=True, interleave=True): + """ + Helper function to stream a list of part files. + + Args: + files: + List of input files which will create a dataset. + parse_fn: + A function that takes a byte tensor containing a datarecord and decodes it. + batch_size: + The batch size for each step. + num_threads: + Number of threads working on the data in parallel. + shuffle: + Shuffle records within each file using ``record_shuffle_size``. Defaults to True. + repeat: + Repeat the dataset indefinitely. Defaults to False. + Useful when you want to use an ``[train,eval]_steps`` greater than the size of the dataset + (otherwise ``Estimator.[train,evaluate]`` stop when the end of the dataset is reached). + block_length (optional): + Number of consecutive records to pull from a single part file. + Defaults to batch_size. + part_file_parallelism (optional): + Number of part files to read from in parallel. Once a part file is completely read, it will + be replaced by the next part file in the part file list. + + ``num_threads`` specifies a reader thread pool size, while ``part_file_parallelism`` specifies + the number of files to read from in parallel. If ``part_file_parallelism`` is greater than or + equal to ``num_threads``, the reads will be distributed over ``num_threads``. On the other hand, + if ``part_file_parallelism`` is smaller than``num_threads``, it is very likely that the reader + thread pool will be underutilized, since it can never be the case that every reader thread has + a part file to read from. + + file_shuffle_size (optional): + the buffer_size used for shuffling of the list of files. 
+      Defaults to 100000. For example, with a buffer_size of 1000 and 2000 files, the first
+      1000 files are shuffled together, iterated through, then the next 1000 files are shuffled
+      and iterated through.
+    record_shuffle_size (optional):
+      the ``buffer_size`` used for shuffling records in each thread.
+      Defaults to ``batch_size * 8`` records.
+    dataset_fn (optional):
+      A function that modifies the dataset after it reads different interleaved part files.
+      Defaults to:
+
+      .. code-block:: python
+
+        def dataset_fn(dataset, parse_fn, batch_size):
+          return dataset.batch(batch_size).map(parse_fn, 1)
+
+    keep_rate (optional):
+      A float value in (0.0, 1.0]; records are dropped according to a Bernoulli
+      distribution with p = 1 - keep_rate.
+      Defaults to None (no records dropped).
+
+    parts_downsampling_rate (optional):
+      A float value in ``(0.0, 1.0]`` that indicates the factor by which to downsample part files.
+      For example, a value of 0.2 means only 20 percent of part files become part of the dataset.
+
+      Note that this argument is only useful in conjunction with a [train,eval]_steps of -1
+      (that is, when the entire dataset is used). Furthermore, note that even in this case, each
+      epoch will see a different set of part files. This is because new part files are re-sampled
+      every epoch. In other words, this argument is only provided for backwards compatibility with
+      DeepBird v1. We recommend you use a smaller [train,eval]_steps (or specify a keep_rate)
+      instead.
+
+    shards (optional):
+      Number of partitions to shard the dataset into. This is useful for codistillation and other
+      techniques that require each worker to train on disjoint partitions of the dataset.
+      The dataset is not sharded by default.
+
+    shard_index (optional):
+      Which partition of the dataset to use if ``shards`` is set.
+
+    shuffle_files (optional):
+      Shuffle the list of files. Defaults to True.
+      When False, files are iterated in the order they are passed in.
+
+    interleave (optional):
+      Interleave records from multiple files in parallel. Defaults to True.
+
+  Returns:
+    tf.data.Dataset of batches of HashedDataRecord resource handles decoded and streamed online.
+  """
+  # Creating a dataset from an input directory
+
+  files = _filenames_dataset(files, shards=shards, shard_index=shard_index)
+
+  file_shuffle_size = file_shuffle_size if file_shuffle_size is not None else 100000
+  record_shuffle_size = record_shuffle_size if record_shuffle_size is not None else (batch_size * 8)
+  block_length = block_length if block_length is not None else batch_size
+
+  logging.info("NUM_THREADS: %d", num_threads)
+
+  if repeat:
+    files = files.repeat()
+
+  if shuffle_files:
+    # Randomly shuffle the files list.
+    files = files.shuffle(buffer_size=file_shuffle_size)
+
+  # Downsample part files
+  files = downsample_dataset(files, parts_downsampling_rate, "parts_downsampling_rate")
+
+  # Interleave the result from BlockFormatDataset
+  # block_length == batch_size results in batch_size records being read from a single file.
+  def map_fn(filenames):
+    '''function that maps each filename to a BlockFormatDataset'''
+    # read each file using BlockFormatDataset
+    dataset = BlockFormatDataset(filenames)
+
+    # early prefetching can sometimes improve performance (like on GCS)
+    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+    # Shuffling before repeating ensures strong ordering.
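+    # note: shuffling inside map_fn (per part file, before interleaving) means the
+    # shuffle buffer only ever needs to hold records from a single part file.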
+ if shuffle: + dataset = dataset.shuffle(buffer_size=record_shuffle_size) + + return dataset + + if interleave: + part_file_parallelism = num_threads if part_file_parallelism is None else part_file_parallelism + dataset = files.interleave( + map_fn, cycle_length=part_file_parallelism, block_length=block_length, num_parallel_calls=num_threads) + else: + dataset = files.flat_map(map_fn) + + # Downsample DataRecords + dataset = downsample_dataset(dataset, keep_rate, "keep_rate") + + if dataset_fn is None: + # Create a batch of datarecords and decode them + return dataset.batch(batch_size).map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(prefetch_size) + + return dataset_fn(dataset, parse_fn, batch_size) + + +def cx_zk_path(path): + if path is None: + raise ValueError("Path for zookeeper dataset pointer is None. You must specify a path.") + return_path = "/".join([DEFAULT_ZOOKEEPER_BASE_ZNODE, path]) + logging.info("Zookeeper path is: {}".format(return_path)) + return return_path + + +def zookeeper_ordered_dataset( + files, parse_fn, batch_size, zk_counter_path, repeat=False, + num_threads=2, block_length=None, part_file_parallelism=None, + batch_shuffle_size=None, file_keep_rate=None, record_keep_rate=None, + prefetch_size=2, interleave=False, dataset_fn=None, verbose=False): + """ + Make a tf.Dataset given an ordered list of filenames, using Zookeeper to keep track of + which file to read, and to coordinate multiple workers. + + Args: + files: + ordered list of (typically HDFS) filenames. This must remain consistent + between different workers, and between worker restarts (e.g. in the case + of instance failure or preemption). + To ensure this remains consistent, consider using the --train.files_list + option from DataRecordTrainer. + parse_fn: + A function that takes a byte tensor containing a datarecord and decodes it. + batch_size: + The batch size for each step. + zk_counter_path: + Path under the root node for the underlying zookeeper shared counter that + is used to coordinate distributed iteration over the list of files. + Full path will be `'/'.join([DEFAULT_ZOOKEEPER_BASE_ZNODE, zk_counter_path])`. + repeat: + Default False. Set True to repeat over the files forever. + num_threads: + Default 2. Number of threads working on the data in parallel. + Only used if interleave=True. + block_length: + Default None. Number of consecutive records to pull from a single part file. + If None, then block_length=batch_size will be used. + Only used if interleave=True. + part_file_parallelism: + Default None. Number of part files to read from in parallel. Once a part file is completely + read, it will be replaced by the next part file indicated by the zookeeper counter. + Only used if interleave=True. + + ``num_threads`` specifies a reader thread pool size, while ``part_file_parallelism`` specifies + the number of files to read from in parallel. If ``part_file_parallelism`` is greater than or + equal to ``num_threads``, the reads will be distributed over ``num_threads``. On the other hand, + if ``part_file_parallelism`` is smaller than``num_threads``, it is very likely that the reader + thread pool will be underutilized, since it can never be the case that every reader thread has + a part file to read from. + + batch_shuffle_size: + Default None. Size of shuffle buffer, for shuffling that will be applied after batching. + if None, then batches will not be shuffled. Ignored if dataset_fn is provided. + file_keep_rate: + Default None. 
+      Fraction of files to keep, or None to keep all files.
+    record_keep_rate:
+      Default None. Fraction of records to keep, or None to keep all records.
+    prefetch_size:
+      Default 2. Number of parsed batches to prefetch. Ignored if dataset_fn is provided.
+    interleave:
+      Default False. Set True to use tf.data.Dataset.interleave rather than flat_map.
+    dataset_fn:
+      A function that is applied to the dataset of individual records, after
+      these have been read from the part files.
+      If ``None`` (the default), the behavior will be as though dataset_fn were set to:
+
+      .. code-block:: python
+
+        def dataset_fn(dataset, parse_fn, batch_size):
+          dataset = dataset.batch(batch_size)
+          dataset = dataset.map(parse_fn, tf.data.experimental.AUTOTUNE)
+          if batch_shuffle_size:
+            dataset = dataset.shuffle(batch_shuffle_size)
+          return dataset.prefetch(prefetch_size)
+
+    verbose:
+      Default False. Set True to log the names of files loaded by TF.
+  """
+  block_length = batch_size if block_length is None else block_length
+  part_file_parallelism = num_threads if part_file_parallelism is None else part_file_parallelism
+
+  def zk_index_generator(my_files=files):
+    zk = KazooClient(hosts=DEFAULT_ZOOKEEPER_HOST)
+    zk.start()
+    my_counter = zk.Counter(cx_zk_path(zk_counter_path), default=0)
+    while True:
+      my_counter += 1
+      counter_pre_value = my_counter.pre_value
+      if repeat:
+        counter_pre_value = counter_pre_value % len(my_files)
+      if counter_pre_value >= len(my_files):
+        break
+      else:
+        chosen_file = my_files[counter_pre_value]
+        if verbose:
+          logging.info("{}. yielding {}".format(counter_pre_value, chosen_file))
+        yield chosen_file
+    zk.stop()
+
+  files = tf.data.Dataset.from_generator(zk_index_generator, tf.string)
+
+  # Downsample part files
+  files = downsample_dataset(files, file_keep_rate, "file_keep_rate")
+
+  def map_fn(filenames):
+    return BlockFormatDataset(filenames).prefetch(20)
+
+  # Don't interleave for sequential training
+  if interleave:
+    dataset = files.interleave(
+      map_fn,
+      cycle_length=part_file_parallelism,
+      block_length=block_length,
+      num_parallel_calls=num_threads)
+  else:
+    dataset = files.flat_map(map_fn)
+
+  # Downsample DataRecords
+  dataset = downsample_dataset(dataset, record_keep_rate, "record_keep_rate")
+
+  if dataset_fn is None:
+    # Create a batch of datarecords and decode them
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    # shuffle after batching and parsing for performance reasons:
+    # faster b/c 1 random selection is made per batch rather than per record
+    if batch_shuffle_size:
+      dataset = dataset.shuffle(buffer_size=batch_shuffle_size)
+    dataset = dataset.prefetch(prefetch_size)
+
+  else:
+    dataset = dataset_fn(dataset, parse_fn, batch_size)
+
+  return dataset
diff --git a/twml/twml/errors.py b/twml/twml/errors.py
new file mode 100644
index 000000000..9b50fcd79
--- /dev/null
+++ b/twml/twml/errors.py
@@ -0,0 +1,13 @@
+"""
+Error classes for twml
+"""
+
+
+class EarlyStopError(Exception):
+  """Exception used to indicate the evaluator needs to stop early."""
+  pass
+
+
+class CheckpointNotFoundError(Exception):
+  """Exception used to indicate a checkpoint hasn't been found."""
+  pass
diff --git a/twml/twml/export_output_fns.py b/twml/twml/export_output_fns.py
new file mode 100644
index 000000000..f72e1d0fe
--- /dev/null
+++ b/twml/twml/export_output_fns.py
@@ -0,0 +1,17 @@
+'''
+Contains implementations of DataRecordTrainer.get_export_output_fns that specify how to
+export model graph outputs from
+build_graph to DataRecords for prediction servers.
+
+Modelers can use the functions in this module as the export_output_fn parameter of
+the DataRecordTrainer constructor to customize how to export their model outputs.
+
+Modelers may also provide a custom implementation of export_output_fn using these as reference.
+'''
+
+# pylint: disable=invalid-name
+from twitter.deepbird.io.legacy.export_output_fns import (
+  batch_prediction_continuous_output_fn,  # noqa: F401
+  batch_prediction_tensor_output_fn,  # noqa: F401
+  default_output_fn,  # noqa: F401
+  variable_length_continuous_output_fn,  # noqa: F401
+)
diff --git a/twml/twml/feature_config.py b/twml/twml/feature_config.py
new file mode 100644
index 000000000..37004f442
--- /dev/null
+++ b/twml/twml/feature_config.py
@@ -0,0 +1,54 @@
+"""
+Feature configuration for DeepBird jobs:
+- Which features to keep
+- Which features to blacklist
+- Which features are labels
+- Which feature is the weight
+"""
+
+from twitter.deepbird.io.legacy import feature_config
+
+
+class FeatureConfig(feature_config.FeatureConfig):
+  def get_feature_spec(self):
+    """
+    Generates a serialization-friendly dict representing this FeatureConfig.
+    """
+    doc = super(FeatureConfig, self).get_feature_spec()
+    # Override the class in the spec.
+    doc["class"] = "twml.FeatureConfig"
+    return doc
+
+
+class FeatureConfigBuilder(feature_config.FeatureConfigBuilder):
+  def build(self):
+    # Overwrite self.build() to return twml.FeatureConfig instead
+    """
+    Builds and returns a FeatureConfig object.
+    """
+
+    (
+      features,
+      tensor_types,
+      sparse_tensor_types,
+      feature_map,
+      feature_name_to_feature_parser,
+      feature_in_bq_name,
+    ) = self._build()
+
+    return FeatureConfig(
+      features=features,
+      labels=self._labels,
+      weight=self._weight,
+      filters=self._filter_features,
+      tensor_types=tensor_types,
+      sparse_tensor_types=sparse_tensor_types,
+      feature_types=feature_map,
+      decode_mode=self._decode_mode,
+      legacy_sparse=self._legacy_sparse,
+      feature_name_to_feature_parser=self._feature_name_to_feature_parser,
+      feature_in_bq_name=self._feature_in_bq_name,
+    )
+
+
+_name_to_id = feature_config._name_to_id
diff --git a/twml/twml/filters.py b/twml/twml/filters.py
new file mode 100644
index 000000000..e48633808
--- /dev/null
+++ b/twml/twml/filters.py
@@ -0,0 +1,9 @@
+'''
+Includes functions to filter features dicts built from
+data records.
+'''
+
+from twitter.deepbird.io.legacy.filters import (
+  balance_binary_class_samples,  # noqa: F401
+  sparse_keep_feature_if,  # noqa: F401
+  sparse_keep_sample_if)  # noqa: F401
diff --git a/twml/twml/hooks.py b/twml/twml/hooks.py
new file mode 100644
index 000000000..cdf733535
--- /dev/null
+++ b/twml/twml/hooks.py
@@ -0,0 +1,562 @@
+""" This file contains tf.train.SessionRunHooks defined by TWML """
+from datetime import datetime
+import json
+import operator
+import os
+
+from absl import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+from tensorflow.python.training.basic_session_run_hooks import NeverTriggerTimer, SecondOrStepTimer
+import twml
+
+
+class StepProgressHook(tf.train.SessionRunHook):
+  """Hook that displays a progress bar to monitor global step progress."""
+
+  def __init__(self, max_step):
+    """
+    Initializes a `StepProgressHook`.
+    This hook displays a progress bar for max_step steps.
+
+    Note that this hook only works for training and calibration.
+
+    Args:
+      max_step:
+        maximum steps to monitor in the progress bar.
+        When this many steps is reached, the progress bar will be full.
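+
+    A minimal usage sketch (the estimator and input_fn below are hypothetical):
+
+    .. code-block:: python
+
+      hooks = [StepProgressHook(max_step=10000)]
+      estimator.train(input_fn, steps=10000, hooks=hooks)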
+ """ + self._max_step = max_step + self._start_step = 0 + self._global_step_tensor = None + self._progress_bar = None + + def begin(self): + """ sets the global_step_tensor """ + self._global_step_tensor = tf.train.get_or_create_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created to use StepProgressHook.") + + def after_create_session(self, session, coord): + """ creates the progress bar and keeps track of the first global step upon session creation """ + global_step = session.run(self._global_step_tensor) + self._start_step = global_step + self._progress_bar = tf.keras.utils.Progbar(self._max_step) + + def before_run(self, run_context): # pylint: disable=unused-argument + """ invoked before calling session.run """ + return tf.train.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + """ invoked after run is called. Updates the progress bar. """ + step = run_context.session.run(self._global_step_tensor) + self._progress_bar.update(step - self._start_step) + + +class GetMetricsHook(tf.train.SessionRunHook): + """ + Hook used to obtain evaluation metrics. + Typically used for early-stopping by obtaining the value of a + metric at the end of an epoch. + Note that the metric tensor and its commensurate update Op + are responsible for aggregating the metric during the session + (one session per epoch). Used for evaluation. + """ + + def __init__(self, get_metrics_fn): + """GetMetricsHook constructor. + + Args: + get_metrics_fn: + Function that returns a dict mapping metric keys to + tensors as a tf.Tensor. + See Trainer.learn for an example use-case. + """ + + self._get_metrics_fn = get_metrics_fn + self._metric_tensors = None + self.metric_values = None + + def begin(self): + """ sets the global_step_tensor and metric tensor""" + self._metric_tensors = self._get_metrics_fn() + assert isinstance(self._metric_tensors, dict) + + def end(self, session): + self.metric_values = session.run(self._metric_tensors) + + +class EarlyStopHook(GetMetricsHook): + """ + A GetMetricsHook augmented with early-stopping logic for use + within the Trainer.learn method. + """ + + def __init__(self, + metric, + patience, + minimize, + get_estimator_spec_fn, + checkpoint_dir, + file_path=None, + exit_on_end=True, + start_epoch=0, + tolerance=0): + """ + Prepare early-stopping hook and variables. + + Args: + metric: + String specifying the metric to early-stop on. Required with positive + ``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc. + The string is used to extract the relevant tensor Op from the dict returned by + the get_eval_metric_ops method. For ``metrics`` pass to the constructor, + the string is one of those. For multi-class (that is, multi-metric) + metrics, the string may be appended with a ``_0``, ``_1``, etc. or one + of the ``multi_metric_names`` (one per class). + patience: + Maximum number of epochs to wait for an improvement in the early_stop_metric + before breaking off training. For example, a patience of 10 means that + training will have 10 epochs to improve the metric before it is killed. + Whenever the metric is improved before running out of patience, + patience is reset to ``early_stop_patience``. + minimize: + Set this to True for metrics that need to be minimized + (like ``loss``). Metrics like ``accuracy`` that need to be maximized + should set this to False. + tolerance: + A non-negative tolerance for comparing early_stop_metric. + e.g. 
when maximizing the condition is current_metric > best_metric + tolerance.
+        Defaults to 0.
+      get_estimator_spec_fn:
+        function that returns the current EstimatorSpec.
+        The EstimatorSpec is used to obtain the current eval_metric_ops.
+      checkpoint_dir:
+        path to directory containing the Estimator checkpoints.
+      file_path:
+        path to file that is used by this hook to communicate early-stopping
+        to StopIfExistsHook. This hook would be used for evaluation, while
+        the StopIfExistsHooks (the listeners) would be used for training.
+        When the file is created, the StopIfExistsHooks detect and terminate training.
+        This argument is used by ``Trainer.train_and_evaluate``.
+      exit_on_end:
+        when the end() method is called to indicate that the session is terminating,
+        and exit_on_end is True, twml.errors.EarlyStopError() is triggered to stop the evaluation job.
+        This is set to False by the trainer for non-distributed jobs.
+      start_epoch:
+        Specifies the starting epoch number. This is used for logging purposes only.
+    """
+    if not isinstance(metric, str):
+      raise ValueError("Expecting string for metric arg")
+    if not isinstance(patience, int):
+      raise ValueError("Expecting positive integer for patience arg")
+
+    self.should_stop = False
+    self._metric = metric
+    self._patience = patience
+    self._current_patience = patience
+    self._checkpoint_dir = checkpoint_dir
+    self._exit_on_end = exit_on_end
+    self._latest_checkpoint_path = None
+    # used for distributed training (tf.estimator.train_and_evaluate)
+    self._file_path = file_path
+    self._epoch = start_epoch
+    if self._file_path is not None:
+      # TODO try to read epoch from a file that we create
+      if tf.io.gfile.exists(self._file_path):
+        # delete the file if it exists (not sure this makes sense)
+        logging.info("EarlyStopHook: Removing existing file: %s.", self._file_path)
+        tf.io.gfile.remove(self._file_path)
+
+    # best_checkpoint dir will contain the best checkpoint
+    self._best_checkpoint_path = os.path.join(checkpoint_dir, 'best_checkpoint')
+    self._eval_checkpoint_path = os.path.join(checkpoint_dir, 'eval_checkpoint')
+    self._best_metric_path = os.path.join(self._best_checkpoint_path, self._metric)
+
+    if tf.io.gfile.exists(self._best_metric_path):
+      with tf.io.gfile.GFile(self._best_metric_path, mode="r") as f:
+        best_metric_from_file = float(f.read())
+    else:
+      best_metric_from_file = None
+
+    if minimize:
+      # current < best : is better
+      self._is_better_than = operator.lt
+      # worse metric possible
+      if best_metric_from_file is None:
+        self._best_metric = np.inf
+      else:
+        self._best_metric = best_metric_from_file - tolerance
+      # used for printing
+      self._early_stop_name = "minimum"
+    else:
+      # current > best : is better
+      self._is_better_than = operator.gt
+      # worse metric possible
+      if best_metric_from_file is None:
+        self._best_metric = -np.inf
+      else:
+        self._best_metric = best_metric_from_file + tolerance
+      # used for printing
+      self._early_stop_name = "maximum"
+
+    def get_metrics_fn():
+      """ function to get metric tensors for early-stopping """
+      estimator_spec = get_estimator_spec_fn()
+      eval_metric_ops = estimator_spec.eval_metric_ops
+      if metric not in eval_metric_ops:
+        raise ValueError(
+          "Expecting early_stop_metric '%s' key in eval_metric_ops dict"
+          % (metric))
+      # get the value_op from the (value_op, update_op) value
+      return {k: v[0] for k, v in eval_metric_ops.items()}
+
+    # initialize GetMetricsHook to get current value of metric from session
+    super(EarlyStopHook, self).__init__(get_metrics_fn=get_metrics_fn)
+
+  def early_stop(self, epoch):
+    """
+    Looks at the current value of the early stopping metric.
+    Decrements current patience. If metric improves, patience is reset
+    and latest checkpoint is moved to checkpoint_dir/best_checkpoint.
+    If current patience reaches zero, returns True.
+
+    Args:
+      epoch:
+        The current epoch number.
+
+    Returns:
+      True when early-stopped. False otherwise.
+    """
+    # decrement patience
+    self._current_patience -= 1
+
+    # get the current metric value
+    current_metric = self.metric_values[self._metric]
+
+    if self._is_better_than(current_metric, self._best_metric):
+      # save best version of model
+      self._best_metric = current_metric
+      logging.info(
+        "Found new %s %s=%f @ epoch %d",
+        self._early_stop_name, self._metric, self._best_metric, epoch)
+      # backup the file to checkpoint_dir/best_checkpoint
+      assert self._latest_checkpoint_path, "expecting latest checkpoint"
+      logging.info("Backing up " + self._latest_checkpoint_path)
+
+      try:
+        eval_checkpoint = tf.train.latest_checkpoint(self._eval_checkpoint_path)
+        twml.util.backup_checkpoint(
+          checkpoint_path_prefix=eval_checkpoint,
+          backup_path=self._best_checkpoint_path)
+      except twml.errors.CheckpointNotFoundError as ex:
+        msg = "Consider increasing 'keep_checkpoint_max' or 'save_checkpoint_secs'"
+        raise twml.errors.CheckpointNotFoundError(str(ex) + "\n" + msg)
+
+      tf.io.gfile.makedirs(os.path.dirname(self._best_metric_path))
+      with tf.io.gfile.GFile(self._best_metric_path, mode="w") as f:
+        # Write with enough precision
+        f.write("%.8f" % self._best_metric)
+
+      # reset patience
+      self._current_patience = self._patience
+
+    elif self._current_patience > 0:
+      logging.info("No new %s found after %d epochs",
+                   self._early_stop_name, self._patience - self._current_patience)
+    elif self._current_patience == 0:
+      logging.info(
+        "No new %s found after %d epochs. Early-stopping experiment.",
+        self._early_stop_name, self._patience)
+      return True
+
+    return False
+
+  def cleanup_checkpoints(self):
+    """
+    makes it so that the best checkpoint is the only checkpoint
+    in checkpoint_dir.
+    """
+    raise NotImplementedError("cleanup_checkpoints is no longer supported")
+
+  def end(self, session):
+    """
+    This method is called at the end of an evaluation/epoch.
+    When file_path constructor argument is provided, this
+    will call ``early_stop()``.
+    When ``early_stop()`` returns True, it creates the file_path,
+    which will be detected by StopIfExistsHooks
+    and stop training for all workers and the chief. It will
+    also call ``cleanup_checkpoints()``.
+    """
+    super(EarlyStopHook, self).end(session)
+
+    # Checks for early stopping criteria and makes a backup
+    self.should_stop = self.early_stop(self._epoch)
+
+    if self._file_path is not None:
+      if self.should_stop:
+        # create a file to inform workers
+        with tf.io.gfile.GFile(self._file_path, "wb") as gfile:
+          gfile.write("early-stop\n")
+        # makes the best checkpoint the only checkpoint in save_dir.
+        msg = "early-stopping evaluation at epoch %d" % self._epoch
+        logging.info(msg)
+        if self._exit_on_end:
+          raise twml.errors.EarlyStopError(msg)
+      else:
+        self._latest_checkpoint_path = None
+
+    self._epoch += 1
+
+  def begin(self):
+    """
+    Saves the latest_checkpoint in case it gets superseded by another checkpoint.
+    Remember that when used with train_and_evaluate, the chief saves checkpoints
+    continuously. The chief could save a checkpoint after evaluation started.
+    So saving the checkpoint at the beginning of evaluation ensures that we
+    later save the correct best checkpoint.
+ """ + super(EarlyStopHook, self).begin() + self._latest_checkpoint_path = tf.train.latest_checkpoint(self._checkpoint_dir) + + assert self._latest_checkpoint_path, "expecting latest checkpoint" + # Backup to temporary directory + try: + twml.util.backup_checkpoint( + checkpoint_path_prefix=self._latest_checkpoint_path, + backup_path=self._eval_checkpoint_path) + except twml.errors.CheckpointNotFoundError as ex: + msg = "Consider increasing 'keep_checkpoint_max' or 'save_checkpoint_secs'" + raise twml.errors.CheckpointNotFoundError(str(ex) + "\n" + msg) + + +class MetricsUpdateHook(GetMetricsHook): + """ + A GetMetricsHook augmented with logic to map SessionRun events to metrics updates. + It is mainly used by `TrackRun` to persist model metrics via Model Repo. + """ + + def __init__(self, + get_estimator_spec_fn, + add_metrics_fn, + every_n_iter=None, + every_n_secs=None + ): + """ + Args: + get_estimator_spec_fn: + function that returns the current EstimatorSpec. + The EstimatorSpec is used to obtain the current eval_metric_ops. + add_metrics_fn: `function` callback used to report metrics, called automatically + at the end of every epoch. + every_n_iter: `int`, log the metrics once every N local + steps taken in the current epoch. + every_n_secs: `int` or `float`, log the metrics once every N + seconds passed in the current epoch. Exactly one of `every_n_iter` and `every_n_secs` + should be provided. + Raises: + ValueError: if `every_n_iter` is non-positive or if not exactly one of `every_n_iter` and + `every_n_secs` is set when `add_progress_metrics_fn` is provided. + """ + only_log_at_end = (every_n_iter is None) and (every_n_secs is None) + + if (not only_log_at_end and every_n_iter and every_n_secs): + raise ValueError( + 'exactly one of every_n_iter and every_n_secs must be provided' + ) + + # TODO: should have a minimum to avoid too many calls to ModelRepo? + if every_n_iter is not None and every_n_iter <= 0: + raise ValueError("invalid every_n_iter=%s." % every_n_iter) + + self._timer = ( + NeverTriggerTimer() if only_log_at_end else + SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter) + ) + + self._should_trigger = False + self._iter_count = 0 + + self._add_metrics_fn = add_metrics_fn + + def get_metrics_fn(): + """ + Function that returns the current EstimatorSpec. + The EstimatorSpec is used to obtain the current eval_metric_ops. + """ + estimator_spec = get_estimator_spec_fn() + eval_metric_ops = estimator_spec.eval_metric_ops + # get the value_op from the (value_op, update_op) value + return {k: v[0] for k, v in eval_metric_ops.items()} + super(MetricsUpdateHook, self).__init__(get_metrics_fn=get_metrics_fn) + + def report_metrics(self): + """ + Triggers a metrics report. + """ + self._timer.update_last_triggered_step(self._iter_count) + if self.metric_values is not None: + self._add_metrics_fn(self.metric_values) + + def begin(self): + """ + Triggered before each epoch. + """ + self._timer.reset() + self._iter_count = 0 + return super(MetricsUpdateHook, self).begin() + + def before_run(self, run_context): + """ + Triggered before each step. + """ + self._should_trigger = self._timer.should_trigger_for_step(self._iter_count) + return super(MetricsUpdateHook, self).before_run(run_context) + + def after_run(self, run_context, run_values): + """ + Triggered after each step. 
+    """
+    if self._should_trigger:
+      self.report_metrics()
+    self._iter_count += 1
+    return super(MetricsUpdateHook, self).after_run(run_context, run_values)
+
+  def end(self, session):
+    """
+    Triggered after each epoch.
+    """
+    self.report_metrics()
+    return super(MetricsUpdateHook, self).end(session)
+
+
+class EarlyStopDuration(tf.train.SessionRunHook):
+  """
+  Hook that can be used to terminate a job (training or validation) after a certain duration.
+  The hook is fault tolerant, i.e., if a job is allotted 1 hour to run and fails after 45 minutes,
+  then it will only run for 15 minutes once restarted.
+
+  Args:
+    max_duration:
+      A float. When this argument is defined, the job will automatically terminate after
+      `max_duration` seconds if it has not already completed.
+
+    overwrite:
+      A boolean. If set to True, this hook will overwrite the file containing the elapsed time
+      since the beginning of the job. In a distributed setting, this will be used so only one
+      job writes to the file while all others will have read access. In a distributed setting,
+      if all executors have this parameter set to False, then it just means that the hook will
+      not be fault tolerant. When restarted, the job will restart the clock from 0.
+
+    save_dir:
+      String. A directory (located on a file system that is Tensorflow compatible) where
+      we can store the file which contains the record of the elapsed time. This file is what makes
+      the hook fault tolerant.
+
+    exit_on_end:
+      when exit_on_end is True, twml.errors.EarlyStopError() is triggered to stop the job.
+      This is usually set to True to kill a validation job in a distributed setting.
+  """
+
+  def __init__(self, max_duration: float, exit_on_end: bool, save_dir: str, overwrite: bool):
+    self._overwrite = overwrite
+    self._save_dir = save_dir
+    self._exit_on_end = exit_on_end
+    self._max_duration = max_duration
+    self._last_time_check = datetime.now()
+
+    # Initialize elapse time file
+    if overwrite:
+      self.elapsed_time()
+
+  @property
+  def elapsed_file_path(self):
+    return os.path.join(self._save_dir, "early_stop_duration.txt")
+
+  def early_stop(self) -> bool:
+    return self.elapsed_time() > self._max_duration
+
+  def elapsed_time(self) -> float:
+    # Recorded elapsed time is 0 unless it's been recorded in a file already
+    recorded_elapsed_time = 0
+    if tf.io.gfile.exists(self.elapsed_file_path):
+      with tf.io.gfile.GFile(self.elapsed_file_path, mode="r") as file:
+        recorded_elapsed_time = json.loads(file.read())["elapsed_time"]
+
+    elapsed_time = recorded_elapsed_time + (datetime.now() - self._last_time_check).total_seconds()
+    self._last_time_check = datetime.now()
+
+    if self._overwrite:
+      # Record the actualized new elapsed time to the file
+      tf.io.gfile.makedirs(os.path.dirname(self.elapsed_file_path))
+      with tf.io.gfile.GFile(self.elapsed_file_path, mode="w") as file:
+        record = {
+          "elapsed_time": elapsed_time,
+          "max_duration": self._max_duration
+        }
+        file.write(json.dumps(record, indent=2))
+
+    return elapsed_time
+
+  def before_run(self, run_context: tf.estimator.SessionRunContext) -> None:
+    if self.early_stop():
+      message = f"""
+      Stopping job which has now exceeded the maximum duration of {self._max_duration} seconds.
+      """
+      logging.info(message)
+      run_context.request_stop()
+
+      if self._exit_on_end:
+        raise twml.errors.EarlyStopError(message)
+
+
+class StopAtStepHook(tf.train.StopAtStepHook):
+  """
+  Overrides ``tf.train.StopAtStepHook`` so that
+  a ``stop_requested`` property can be accessed to determine
+  if this hook requested a stop.
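+
+  A usage sketch (``estimator`` and ``input_fn`` are assumed to already exist;
+  they are not defined by this module):
+
+  .. code-block:: python
+
+    hook = twml.hooks.StopAtStepHook(last_step=100000)
+    estimator.train(input_fn, hooks=[hook])
+    if hook.stop_requested:
+      logging.info("training was stopped by StopAtStepHook")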
+ """ + + def __init__(self, *args, **kwargs): + super(StopAtStepHook, self).__init__(*args, **kwargs) + self._stop_requested = False + + @property + def stop_requested(self): + """ true if this hook requested a stop """ + return self._stop_requested + + def after_run(self, run_context, run_values): + """ sets self.stop_requested to true when requesting a stop """ + super(StopAtStepHook, self).after_run(run_context, run_values) + self._stop_requested = run_context.stop_requested + + +class StopIfExistsHook(tf.train.SessionRunHook): + """ + Hook that requests stop if a file exists. + This hook is used with the EarlyStopHook to implement + early-stopping for distributed training (tf.estimator.train_and_evaluate). + """ + + def __init__(self, file_path): + """ + Arguments: + file_path: + path to file. When this hook detects that the file exists, + it requests a stop, which effectively kills this worker. + """ + self._file_path = file_path + self._stop_requested = False + + def after_run(self, run_context, run_values): + if tf.io.gfile.exists(self._file_path): + logging.info("Early-stopping file detected; requesting stop") + run_context.request_stop() + self._stop_requested = True + + @property + def stop_requested(self): + """ true if this hook requested a stop """ + return self._stop_requested diff --git a/twml/twml/input_fns.py b/twml/twml/input_fns.py new file mode 100644 index 000000000..394fc8674 --- /dev/null +++ b/twml/twml/input_fns.py @@ -0,0 +1,129 @@ +''' +Contains implementations of functions to read input data. +''' +from .dataset import stream_block_format_dataset + +import tensorflow.compat.v1 as tf + + +def data_record_input_fn( + files, batch_size, parse_fn, + num_threads=2, repeat=False, dataset_fn=None, + keep_rate=None, parts_downsampling_rate=None, + shards=None, shard_index=None, shuffle=True, shuffle_files=True, interleave=True, + initializable=False, log_tf_data_summaries=False, + **kwargs): + """ + Returns a nested structure of tf.Tensors containing the next element. + Used by ``train_input_fn`` and ``eval_input_fn`` in DataRecordTrainer. + By default, works with DataRecord dataset for compressed partition files. + + Args: + files: + List of files that will be parsed. + batch_size: + number of samples per batch. + parse_fn: + function passed to data loading for parsing individual data records. + Usually one of the decoder functions like ``parsers.get_sparse_parse_fn``. + num_threads (optional): + number of threads used for loading data. Defaults to 2. + repeat (optional): + Repeat the dataset indefinitely. Defaults to False. + Useful when you want to use ``train_steps`` or ``eval_steps`` + greater than the size of the dataset + (otherwise Estimator.[train,evaluate] stops when the end of the dataset is reached). + dataset_fn (optional): + A function that modifies the dataset after it reads different interleaved parts files. + Defaults to: + + .. code-block:: python + + def dataset_fn(dataset, parse_fn, batch_size): + return dataset.batch(batch_size).map(parse_fn, 1) + + keep_rate (optional): + A float value in (0.0, 1.0] that indicates to drop records according to the Bernoulli + distribution with p = 1 - keep_rate. + Defaults to None (no records dropped). + + parts_downsampling_rate (optional): + A float value in (0.0, 1.0] that indicates the factor by which to downsample part files. + For example, a value of 0.2 means only 20 percent of part files become part of the dataset. + + shards (optional): + Number of partitions to shard the dataset into. 
This is useful for codistillation
+      (https://arxiv.org/pdf/1804.03235.pdf) and other techniques that require each worker to
+      train on disjoint partitions of the dataset.
+      The dataset is not sharded by default.
+
+    shard_index (optional):
+      Which partition of the dataset to use if ``shards`` is set.
+
+    shuffle (optional):
+      Whether to shuffle the records. Defaults to True.
+
+    shuffle_files (optional):
+      Shuffle the list of files. Defaults to True.
+      When False, files are iterated in the order they are passed in.
+
+    interleave (optional):
+      Interleave records from multiple files in parallel. Defaults to True.
+
+    initializable (optional):
+      A boolean indicator. When the Dataset Iterator depends on some resource, e.g. a HashTable or
+      a Tensor, i.e. it's an initializable iterator, set it to True. Otherwise, default value (false)
+      is used for most plain iterators.
+
+    log_tf_data_summaries (optional):
+      A boolean indicator denoting whether to add a `tf.data.experimental.StatsAggregator` to the
+      tf.data pipeline. This adds summaries of pipeline utilization and buffer sizes to the output
+      events files. This requires that `initializable` is `True` above.
+
+  Returns:
+    Iterator of elements of the dataset.
+  """
+  if not parse_fn:
+    raise ValueError("default_input_fn requires a parse_fn")
+
+  if log_tf_data_summaries and not initializable:
+    raise ValueError("Require `initializable` if `log_tf_data_summaries`.")
+
+  dataset = stream_block_format_dataset(
+    files=files,
+    parse_fn=parse_fn,
+    batch_size=batch_size,
+    repeat=repeat,
+    num_threads=num_threads,
+    dataset_fn=dataset_fn,
+    keep_rate=keep_rate,
+    parts_downsampling_rate=parts_downsampling_rate,
+    shards=shards,
+    shard_index=shard_index,
+    shuffle=shuffle,
+    shuffle_files=shuffle_files,
+    interleave=interleave,
+    **kwargs
+  )
+
+  # Add a tf.data.experimental.StatsAggregator
+  # https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/data/experimental/StatsAggregator
+  if log_tf_data_summaries:
+    aggregator = tf.data.experimental.StatsAggregator()
+    options = tf.data.Options()
+    options.experimental_stats.aggregator = aggregator
+    dataset = dataset.with_options(options)
+    stats_summary = aggregator.get_summary()
+    tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
+
+  if initializable:
+    # when the data parsing depends on some HashTable or Tensor, the iterator is initializable and
+    # therefore its initializer needs to be run explicitly
+    iterator = dataset.make_initializable_iterator()
+    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
+  else:
+    iterator = dataset.make_one_shot_iterator()
+  return iterator.get_next()
+
+
+default_input_fn = data_record_input_fn  # pylint: disable=invalid-name
diff --git a/twml/twml/layers/__init__.py b/twml/twml/layers/__init__.py
new file mode 100644
index 000000000..917c61867
--- /dev/null
+++ b/twml/twml/layers/__init__.py
@@ -0,0 +1,21 @@
+# pylint: disable=wildcard-import
+"""
+This module contains the ``tf.layers.Layer`` subclasses implemented in twml.
+Layers are used to instantiate common subgraphs.
+Typically, these layers are used when defining a ``build_graph_fn``
+for the ``twml.trainers.Trainer``.
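+
+For example, a ``build_graph_fn`` might combine a sparse and a dense layer roughly
+as follows (a sketch only; the feature name and sizes are illustrative assumptions):
+
+.. code-block:: python
+
+  import tensorflow.compat.v1 as tf
+  import twml
+
+  def build_graph(features, label, mode, params, config=None):
+    # 'sparse_input' is a hypothetical tf.SparseTensor feature
+    hidden = twml.layers.full_sparse(features['sparse_input'], output_size=50,
+                                     activation=tf.nn.relu)
+    logits = twml.layers.full_dense(hidden, output_size=1)
+    # ... build and return an EstimatorSpec from logits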
+"""
+
+from .batch_prediction_tensor_writer import BatchPredictionTensorWriter  # noqa: F401
+from .batch_prediction_writer import BatchPredictionWriter  # noqa: F401
+from .data_record_tensor_writer import DataRecordTensorWriter  # noqa: F401
+from .full_dense import full_dense, FullDense  # noqa: F401
+from .full_sparse import full_sparse, FullSparse  # noqa: F401
+from .isotonic import Isotonic  # noqa: F401
+from .layer import Layer  # noqa: F401
+from .mdl import MDL  # noqa: F401
+from .partition import Partition  # noqa: F401
+from .percentile_discretizer import PercentileDiscretizer  # noqa: F401
+from .sequential import Sequential  # noqa: F401
+from .sparse_max_norm import MaxNorm, sparse_max_norm, SparseMaxNorm  # noqa: F401
+from .stitch import Stitch  # noqa: F401
diff --git a/twml/twml/layers/batch_prediction_tensor_writer.py b/twml/twml/layers/batch_prediction_tensor_writer.py
new file mode 100644
index 000000000..3f6633a8e
--- /dev/null
+++ b/twml/twml/layers/batch_prediction_tensor_writer.py
@@ -0,0 +1,51 @@
+# pylint: disable=no-member, invalid-name
+"""
+Implementing Writer Layer
+"""
+from .layer import Layer
+
+import libtwml
+
+
+class BatchPredictionTensorWriter(Layer):
+  """
+  A layer that packages keys and dense tensors into a BatchPredictionResponse.
+  Typically used at the output of an exported model for use in the PredictionEngine
+  (that is, in production) when model predictions are dense tensors.
+
+  Arguments:
+    keys:
+      keys to hashmap
+  Output:
+    output:
+      a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
+  """
+
+  def __init__(self, keys, **kwargs):  # pylint: disable=useless-super-delegation
+    super(BatchPredictionTensorWriter, self).__init__(**kwargs)
+    self.keys = keys
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+
+    Raises NotImplementedError.
+
+    """
+    raise NotImplementedError
+
+  def call(self, values, **kwargs):  # pylint: disable=unused-argument, arguments-differ
+    """The logic of the layer lives here.
+
+    Arguments:
+      values:
+        dense tensors corresponding to keys in hashmap
+
+    Returns:
+      The output from the layer
+    """
+    write_op = libtwml.ops.batch_prediction_tensor_response_writer(self.keys, values)
+    return write_op
diff --git a/twml/twml/layers/batch_prediction_writer.py b/twml/twml/layers/batch_prediction_writer.py
new file mode 100644
index 000000000..118d21921
--- /dev/null
+++ b/twml/twml/layers/batch_prediction_writer.py
@@ -0,0 +1,51 @@
+# pylint: disable=no-member, invalid-name
+"""
+Implementing Writer Layer
+"""
+from .layer import Layer
+
+import libtwml
+
+
+class BatchPredictionWriter(Layer):
+  """
+  A layer that packages keys and values into a BatchPredictionResponse.
+  Typically used at the output of an exported model for use in the PredictionEngine
+  (that is, in production).
+
+  Arguments:
+    keys:
+      keys to hashmap
+  Output:
+    output:
+      a BatchPredictionResponse serialized using Thrift into a uint8 tensor.
+  """
+
+  def __init__(self, keys, **kwargs):  # pylint: disable=useless-super-delegation
+    super(BatchPredictionWriter, self).__init__(**kwargs)
+    self.keys = keys
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+ + Raise NotImplementedError. + + """ + raise NotImplementedError + + def call(self, values, **kwargs): # pylint: disable=unused-argument, arguments-differ + """The logic of the layer lives here. + + Arguments: + values: + values corresponding to keys in hashmap + + Returns: + The output from the layer + """ + write_op = libtwml.ops.batch_prediction_response_writer(self.keys, values) + return write_op diff --git a/twml/twml/layers/data_record_tensor_writer.py b/twml/twml/layers/data_record_tensor_writer.py new file mode 100644 index 000000000..0f70186b4 --- /dev/null +++ b/twml/twml/layers/data_record_tensor_writer.py @@ -0,0 +1,50 @@ +# pylint: disable=no-member, invalid-name +""" +Implementing Writer Layer +""" +from .layer import Layer + +import libtwml + + +class DataRecordTensorWriter(Layer): + """ + A layer that packages keys and dense tensors into a DataRecord. + This layer was initially added to support exporting user embeddings as tensors. + + Arguments: + keys: + keys to hashmap + Output: + output: + a DataRecord serialized using Thrift into a uint8 tensor + """ + + def __init__(self, keys, **kwargs): # pylint: disable=useless-super-delegation + super(DataRecordTensorWriter, self).__init__(**kwargs) + self.keys = keys + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. the batch size may be unknown). + + Raises NotImplementedError. + + """ + raise NotImplementedError + + def call(self, values, **kwargs): # pylint: disable=unused-argument, arguments-differ + """The logic of the layer lives here. + + Arguments: + values: + dense tensors corresponding to keys in hashmap + + Returns: + The output from the layer + """ + write_op = libtwml.ops.data_record_tensor_writer(self.keys, values) + return write_op diff --git a/twml/twml/layers/full_dense.py b/twml/twml/layers/full_dense.py new file mode 100644 index 000000000..9c354ad3e --- /dev/null +++ b/twml/twml/layers/full_dense.py @@ -0,0 +1,259 @@ +# pylint: disable=no-member,arguments-differ, attribute-defined-outside-init +""" +Implementing Full Dense Layer +""" +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import init_ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine.base_layer import InputSpec +import tensorflow.compat.v1 as tf + + +class FullDense(core_layers.Dense): + """ + Densely-connected layer class. + This is wrapping tensorflow.python.layers.core.Dense + This layer implements the operation: + + .. code-block:: python + + outputs = activation(inputs.weight + bias) + + Where ``activation`` is the activation function passed as the ``activation`` + argument (if not ``None``), ``weight`` is a weights matrix created by the layer, + and ``bias`` is a bias vector created by the layer. + + Arguments: + output_size: + Integer or Long, dimensionality of the output space. + activation: + Activation function (callable). Set it to None to maintain a linear activation. + weight_initializer: + Initializer function for the weight matrix. + bias_initializer: + Initializer function for the bias. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + bias_regularizer: + Regularizer function for the bias. 
+ Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + activity_regularizer: + Regularizer function for the output. + weight_constraint: + An optional projection function to be applied to the + weight after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: + An optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: + Boolean, if `True` also add variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + name: + String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + + Properties: + output_size: + Python integer, dimensionality of the output space. + activation: + Activation function (callable). + weight_initializer: + Initializer instance (or name) for the weight matrix. + bias_initializer: + Initializer instance (or name) for the bias. + weight: + Weight matrix (TensorFlow variable or tensor). (weight) + bias: + Bias vector, if applicable (TensorFlow variable or tensor). + weight_regularizer: + Regularizer instance for the weight matrix (callable) + bias_regularizer: + Regularizer instance for the bias (callable). + activity_regularizer: + Regularizer instance for the output (callable) + weight_constraint: + Constraint function for the weight matrix. + bias_constraint: + Constraint function for the bias. + + """ + + def __init__(self, output_size, + weight_initializer=None, + weight_regularizer=None, + weight_constraint=None, + bias_constraint=None, + num_partitions=None, + **kwargs): + super(FullDense, self).__init__(units=output_size, + kernel_initializer=weight_initializer, + kernel_regularizer=weight_regularizer, + kernel_constraint=weight_constraint, + **kwargs) + self._num_partitions = num_partitions + + def build(self, input_shape): + ''' + code adapted from TF 1.12 Keras Dense layer: + https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/layers/core.py#L930-L956 + ''' + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape[-1] is None: + raise ValueError('The last dimension of the inputs to `Dense` ' + 'should be defined. 
Found `None`.')
+    self.input_spec = InputSpec(min_ndim=2,
+                                axes={-1: input_shape[-1]})
+
+    partitioner = None
+    if self._num_partitions:
+      partitioner = tf.fixed_size_partitioner(self._num_partitions)
+
+    self.kernel = self.add_weight(
+      'kernel',
+      shape=[input_shape[-1], self.units],
+      initializer=self.kernel_initializer,
+      regularizer=self.kernel_regularizer,
+      constraint=self.kernel_constraint,
+      dtype=self.dtype,
+      partitioner=partitioner,
+      trainable=True)
+
+    if self.use_bias:
+      self.bias = self.add_weight(
+        'bias',
+        shape=[self.units, ],
+        initializer=self.bias_initializer,
+        regularizer=self.bias_regularizer,
+        constraint=self.bias_constraint,
+        dtype=self.dtype,
+        trainable=True)
+    else:
+      self.bias = None
+    self.built = True
+
+  @property
+  def output_size(self):
+    """
+    Returns output_size
+    """
+    return self.units
+
+  @property
+  def weight(self):
+    """
+    Returns weight
+    """
+    return self.kernel
+
+  @property
+  def weight_regularizer(self):
+    """
+    Returns weight_regularizer
+    """
+    return self.kernel_regularizer
+
+  @property
+  def weight_initializer(self):
+    """
+    Returns weight_initializer
+    """
+    return self.kernel_initializer
+
+  @property
+  def weight_constraint(self):
+    """
+    Returns weight_constraint
+    """
+    return self.kernel_constraint
+
+
+def full_dense(inputs, output_size,
+               activation=None,
+               use_bias=True,
+               weight_initializer=None,
+               bias_initializer=init_ops.zeros_initializer(),
+               weight_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               weight_constraint=None,
+               bias_constraint=None,
+               trainable=True,
+               name=None,
+               num_partitions=None,
+               reuse=None):
+  """Functional interface for the densely-connected layer.
+  This layer implements the operation:
+  `outputs = activation(inputs.weight + bias)`
+  Where `activation` is the activation function passed as the `activation`
+  argument (if not `None`), `weight` is a weights matrix created by the layer,
+  and `bias` is a bias vector created by the layer
+  (only if `use_bias` is `True`).
+
+  Arguments:
+    inputs: Tensor input.
+    output_size: Integer or Long, dimensionality of the output space.
+    activation: Activation function (callable). Set it to None to maintain a
+      linear activation.
+    use_bias: Boolean, whether the layer uses a bias.
+    weight_initializer: Initializer function for the weight matrix.
+      If `None` (default), weights are initialized using the default
+      initializer used by `tf.get_variable`.
+    bias_initializer:
+      Initializer function for the bias.
+    weight_regularizer:
+      Regularizer function for the weight matrix.
+      Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
+    bias_regularizer:
+      Regularizer function for the bias.
+      Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect.
+    activity_regularizer:
+      Regularizer function for the output.
+    weight_constraint:
+      An optional projection function to be applied to the
+      weight after being updated by an `Optimizer` (e.g. used to implement
+      norm constraints or value constraints for layer weights). The function
+      must take as input the unprojected variable and must return the
+      projected variable (which must have the same shape). Constraints are
+      not safe to use when doing asynchronous distributed training.
+    bias_constraint:
+      An optional projection function to be applied to the
+      bias after being updated by an `Optimizer`.
+    trainable:
+      Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: + String, the name of the layer. + reuse: + Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor the same shape as `inputs` except the last dimension is of + size `units`. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = FullDense(output_size, + activation=activation, + use_bias=use_bias, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + weight_regularizer=weight_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + weight_constraint=weight_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + dtype=inputs.dtype.base_dtype, + num_partitions=num_partitions, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) diff --git a/twml/twml/layers/full_sparse.py b/twml/twml/layers/full_sparse.py new file mode 100644 index 000000000..4f0f21930 --- /dev/null +++ b/twml/twml/layers/full_sparse.py @@ -0,0 +1,370 @@ +# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument +""" +Implementing Full Sparse Layer +""" + +import math + +from twitter.deepbird.sparse import sparse_dense_matmul + +from .layer import Layer + +import tensorflow.compat.v1 as tf +import twml + + +class FullSparse(Layer): + """Fully-sparse layer class. + This layer implements the operation: + + .. code-block:: python + + outputs = activation(inputs.weight + bias) + + Arguments: + output_size: + Long or Integer, dimensionality of the output space. + input_size: + The number of input units. (Deprecated) + weight_initializer: + Initializer function for the weight matrix. + This argument defaults to zeros_initializer(). + This is valid when the FullSparse is the first layer of + parameters but should be changed otherwise. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + bias_regularizer: + Regularizer function for the bias. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect + activation: + Activation function (callable). Set it to None to maintain a linear activation. + bias_initializer: + Initializer function for the bias. + This argument defaults to tf.constant_initializer(1/output_size) + trainable: + Boolean, if `True` also add variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + name: + String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + use_sparse_grads: + Boolean, if `True` do sparse mat mul with `embedding_lookup_sparse`, which will + make gradients to weight matrix also sparse in backward pass. This can lead to non-trivial + speed up at training time when input_size is large and optimizer handles sparse gradients + correctly (eg. with SGD or LazyAdamOptimizer). If weight matrix is small, it's recommended + to set this flag to `False`; for most use cases of FullSparse, however, weight matrix will + be large, so it's better to set it to `True` + num_partitions: + Number of partitions to use for the weight variable. Defaults to 1. + partition_axis: + If num_partitions is specified, the partition axis for the weight variable + Defaults to 0 (partition by row). + Must be 0 (row) or 1 (column) + use_binary_values: + Assume all non zero values are 1. Defaults to False. 
+ This can improve training if used in conjunction with MDL. + This parameter can also be a list of binary values if `inputs` passed to `call` a list. + use_compression: + Default False. Set True to enable data compression techniques for + optimization of network traffic for distributed training. + use_binary_sparse_dense_matmul: + If binary sparse dense matmul op is to be used. It will only be enabled if + `use_binary_values` is set true. It only should be used for inference, best practice is + to set `use_binary_sparse_dense_matmul = not is_training`. + """ + + def __init__(self, + output_size, + input_size=None, + weight_initializer=None, + activation=None, + bias_initializer=None, + trainable=True, + name=None, + use_sparse_grads=True, + num_partitions=None, + partition_axis=0, + use_binary_values=False, + bias_regularizer=None, + weight_regularizer=None, + use_compression=False, + use_binary_sparse_dense_matmul=False, + **kwargs): + super(FullSparse, self).__init__(trainable=trainable, name=name, **kwargs) + # TODO - remove input_size warning. + if input_size: + raise ValueError('input_size is deprecated - it is now automatically \ + inferred from your input.') + + # The bias initialization and weights initialization is set to match v1's implementation. + if bias_initializer is None: + bias_initializer = tf.constant_initializer(1 / output_size) + # Weights initialization is set to 0s. This is safe for full sparse layers because + # you are supposed to learn your embedding from the label. + if weight_initializer is None: + weight_initializer = tf.zeros_initializer() + self.weight_initializer = weight_initializer + self.bias_initializer = bias_initializer + self.output_size = output_size + self.activation = activation + self.use_sparse_grads = use_sparse_grads + self.num_partitions = num_partitions + if partition_axis != 0 and partition_axis != 1: + raise ValueError('partition_axis must be 0 or 1') + self.partition_axis = partition_axis + self.use_binary_values = use_binary_values + self.weight_regularizer = weight_regularizer + self.bias_regularizer = bias_regularizer + self._use_compression = use_compression + self._cast_indices_dtype = tf.int32 if self._use_compression else None + self.use_binary_sparse_dense_matmul = use_binary_sparse_dense_matmul + + def _make_weight_var(self, shape, partitioner): + self.weight = self.add_variable( + 'weight', + initializer=self.weight_initializer, + regularizer=self.weight_regularizer, + shape=shape, + dtype=self.dtype, + trainable=True, + partitioner=partitioner, + ) + + def build(self, input_shapes): + """ + creates the ``bias`` and ``weight`` Variables + of shape ``[output_size]`` and ``[input_size, output_size]`` respectively. + """ + + if isinstance(input_shapes, (list, tuple)): + input_shape = input_shapes[0] + is_compatible = True + for other_shape in input_shapes[1:]: + is_compatible &= input_shape.is_compatible_with(other_shape) + if not is_compatible: + raise ValueError("Input shapes %s are not compatible." % input_shapes) + else: + input_shape = input_shapes + + self.bias = self.add_variable( + 'bias', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + shape=[self.output_size, ], + dtype=self.dtype, + trainable=True + ) + + partitioner = None + shape = [input_shape[1], self.output_size] + + # There is a 2gb limitation for each tensor because of protobuf. + # 2**30 is 1GB. 2 * (2**30) is 2GB. 
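+    # The size check below follows from that limit: each partition holds
+    # roughly ceil(split_dim / num_partitions) * other_dim weights of
+    # dtype.size bytes each, and must stay under 2**31 bytes.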
+    dtype = tf.as_dtype(self.dtype)
+    num_partitions = 1 if self.num_partitions is None else self.num_partitions
+    in_shape = input_shape[1]
+    out_shape = self.output_size
+
+    # when v2 behavior is disabled, in_shape is tf.Dimension. Otherwise it is int.
+    if isinstance(in_shape, tf.Dimension):
+      in_shape = in_shape.value
+
+    if in_shape is None:
+      raise ValueError("Input tensor should have shape."
+                       " You can set it using twml.util.limit_sparse_tensor_size")
+
+    (split_dim, other_dim) = (in_shape, out_shape) if self.partition_axis == 0 else (out_shape, in_shape)
+    requested_size = math.ceil(float(split_dim) / num_partitions) * other_dim * dtype.size
+    if (requested_size >= 2**31):
+      raise ValueError("Weight tensor partitions cannot be larger than 2GB.\n"
+                       "Requested Dimensions(%d, %d) of type %s (%d bytes total) over %d partitions.\n"
+                       "Possible solutions:\n"
+                       "- reduce the params.output_size_bits\n"
+                       "- reduce the output_size of the sparse_layer\n"
+                       "- specify a larger num_partitions argument\n"
+                       "- reduce input_size_bits" %
+                       (in_shape, self.output_size, dtype.name, requested_size, num_partitions))
+
+    if self.num_partitions:
+      partition_axis = int(self.partition_axis)
+      partitioner = tf.fixed_size_partitioner(self.num_partitions, axis=partition_axis)
+    else:
+      # Regular variables do not like it when you pass both constant tensors and shape
+      if not callable(self.weight_initializer):
+        shape = None
+
+    self._make_weight_var(shape, partitioner)
+
+    self.built = True
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+
+    Raises NotImplementedError.
+
+    """
+    raise NotImplementedError
+
+  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
+    """The logic of the layer lives here.
+
+    Arguments:
+      inputs:
+        A SparseTensor or a list of SparseTensors.
+        If `inputs` is a list, all tensors must have same `dense_shape`.
+
+    Returns:
+      - If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`.
+      - If `inputs` is a `list[SparseTensor]`, then returns
+        `bias + add_n([sp_a * dense_b for sp_a in inputs])`.
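+
+    A sketch of the single-input case (shapes and sizes are illustrative
+    assumptions):
+
+    .. code-block:: python
+
+      # sp_a: a [batch_size, input_size] tf.SparseTensor
+      layer = twml.layers.FullSparse(output_size=10)
+      outputs = layer(sp_a)  # bias + sp_a * weight, shape [batch_size, 10]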
+ + """ + if isinstance(inputs, (list, tuple)): + + if isinstance(self.use_binary_values, (list, tuple)): + use_binary_values = self.use_binary_values + else: + use_binary_values = [self.use_binary_values] * len(inputs) + + num_inputs = len(inputs) + if num_inputs != len(use_binary_values): + raise ValueError("#inputs is %d while #use_binary_values is %d" + % (num_inputs, len(use_binary_values))) + + outputs = [] + for n in range(num_inputs): + outputs.append(sparse_dense_matmul(inputs[n], self.weight, + self.use_sparse_grads, + use_binary_values[n], + name='sparse_mm_' + str(n), + partition_axis=self.partition_axis, + num_partitions=self.num_partitions, + compress_ids=self._use_compression, + cast_indices_dtype=self._cast_indices_dtype, + use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul)) + outputs = tf.accumulate_n(outputs) + else: + + if isinstance(self.use_binary_values, (list, tuple)): + raise ValueError("use_binary_values can not be %s when inputs is %s" % + (type(self.use_binary_values), type(inputs))) + + outputs = sparse_dense_matmul(inputs, self.weight, + self.use_sparse_grads, + self.use_binary_values, + name='sparse_mm', + partition_axis=self.partition_axis, + num_partitions=self.num_partitions, + compress_ids=self._use_compression, + cast_indices_dtype=self._cast_indices_dtype, + use_binary_sparse_dense_matmul=self.use_binary_sparse_dense_matmul) + + if self.bias is not None: + outputs = tf.nn.bias_add(outputs, self.bias) + + if self.activation is not None: + return self.activation(outputs) # pylint: disable=not-callable + return outputs + + +def full_sparse( + inputs, output_size, + input_size=None, + activation=None, + bias_regularizer=None, + weight_regularizer=None, + bias_initializer=None, + weight_initializer=None, + trainable=True, + name=None, + reuse=None, + use_sparse_grads=True, + num_partitions=None, + partition_axis=0, + use_binary_values=False, + use_compression=False): + """Functional interface for the sparsely-connected layer. + + Arguments: + inputs: + A sparse tensor (can be twml.SparseTensor or tf.SparseTensor) + output_size: + Long or Integer, dimensionality of the output space. + weight_initializer: + Initializer function for the weight matrix. + activation: + Activation function (callable). Set it to None to maintain a linear activation. + bias_initializer: + Initializer function for the bias. + weight_regularizer: + Regularizer function for the weight matrix. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + bias_regularizer: + Regularizer function for the bias. + Ensure to add tf.losses.get_regularization_loss() to your loss for this to take effect. + trainable: + Boolean, if `True` also add variables to the graph collection + ``GraphKeys.TRAINABLE_VARIABLES`` (see `tf.Variable + `_). + name: + String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require ``reuse=True`` in such cases. + use_sparse_grads: + Boolean, if `True` do sparse mat mul with `embedding_lookup_sparse`, which will + make gradients to weight matrix also sparse in backward pass. This can lead to non-trivial + speed up at training time when input_size is large and optimizer handles sparse gradients + correctly (eg. with SGD or LazyAdamOptimizer). 
If weight matrix is small, it's recommended
+      to set this flag to `False`; for most use cases of FullSparse, however, weight matrix will
+      be large, so it's better to set it to `True`
+    num_partitions:
+      Number of partitions to use for the weight variable. Defaults to 1.
+    partition_axis:
+      If num_partitions is specified, the partition axis for the weight variable
+      Defaults to 0 (partition by row).
+      Must be 0 (row) or 1 (column)
+    use_binary_values:
+      Assume all non zero values are 1. Defaults to False.
+      This can improve training if used in conjunction with MDL.
+    use_compression:
+      Default False. Set True to enable data compression techniques for
+      optimization of network traffic for distributed training.
+  Returns:
+    Outputs a ``tf.Tensor`` of size ``[batch_size x output_size]``.
+  """
+  # TODO - remove input_size warning.
+  if input_size:
+    raise ValueError('input_size is deprecated - it is now \
+      automatically inferred from your input.')
+
+  dtype = None
+  if isinstance(inputs, twml.SparseTensor):
+    inputs = inputs.to_tf()
+    dtype = inputs.dtype.base_dtype
+
+  if isinstance(inputs, (list, tuple)):
+    inputs = [inp.to_tf() if isinstance(inp, twml.SparseTensor) else inp for inp in inputs]
+    dtype = inputs[0].dtype.base_dtype
+
+  layer = FullSparse(output_size=output_size,
+                     activation=activation,
+                     trainable=trainable,
+                     name=name,
+                     weight_initializer=weight_initializer,
+                     bias_initializer=bias_initializer,
+                     weight_regularizer=weight_regularizer,
+                     bias_regularizer=bias_regularizer,
+                     dtype=dtype,
+                     _scope=name,
+                     _reuse=reuse,
+                     use_sparse_grads=use_sparse_grads,
+                     num_partitions=num_partitions,
+                     partition_axis=partition_axis,
+                     use_compression=use_compression,
+                     use_binary_values=use_binary_values)
+  return layer(inputs)
diff --git a/twml/twml/layers/isotonic.py b/twml/twml/layers/isotonic.py
new file mode 100644
index 000000000..7113f7af4
--- /dev/null
+++ b/twml/twml/layers/isotonic.py
@@ -0,0 +1,76 @@
+# pylint: disable=no-member, invalid-name, attribute-defined-outside-init
+"""
+Contains the Isotonic Layer
+"""
+
+from .layer import Layer
+
+import libtwml
+import numpy as np
+
+
+class Isotonic(Layer):
+  """
+  This layer is created by the IsotonicCalibrator.
+  Typically it is used instead of sigmoid activation on the output unit.
+
+  Arguments:
+    n_unit:
+      number of input units to the layer (same as number of output units).
+    n_bin:
+      number of bins used for isotonic calibration.
+      More bins means a more precise isotonic function.
+      Less bins means a more regularized isotonic function.
+    xs_input:
+      A tensor containing the boundaries of the bins.
+    ys_input:
+      A tensor containing calibrated values for the corresponding bins.
+
+  Output:
+    output:
+      A layer containing calibrated probabilities with same shape and size as input.
+  Expected Sizes:
+    xs_input, ys_input:
+      [n_unit, n_bin].
+  Expected Types:
+    xs_input, ys_input:
+      same as input.
+  """
+
+  def __init__(self, n_unit, n_bin, xs_input=None, ys_input=None, **kwargs):
+    super(Isotonic, self).__init__(**kwargs)
+
+    self._n_unit = n_unit
+    self._n_bin = n_bin
+
+    self.xs_input = np.empty([n_unit, n_bin], dtype=np.float32) if xs_input is None else xs_input
+    self.ys_input = np.empty([n_unit, n_bin], dtype=np.float32) if ys_input is None else ys_input
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+
+    Raises NotImplementedError.
+
+    """
+    raise NotImplementedError
+
+  def build(self, input_shape):  # pylint: disable=unused-argument
+    """Creates the variables of the layer."""
+
+    self.built = True
+
+  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
+    """The logic of the layer lives here.
+
+    Arguments:
+      inputs: input tensor(s).
+
+    Returns:
+      The output from the layer
+    """
+    calibrate_op = libtwml.ops.isotonic_calibration(inputs, self.xs_input, self.ys_input)
+    return calibrate_op
diff --git a/twml/twml/layers/layer.py b/twml/twml/layers/layer.py
new file mode 100644
index 000000000..c1b00eb13
--- /dev/null
+++ b/twml/twml/layers/layer.py
@@ -0,0 +1,50 @@
+# pylint: disable=no-member
+"""
+Implementing a base layer for twml
+"""
+import tensorflow.compat.v1 as tf
+from tensorflow.python.layers import base
+
+
+class Layer(base.Layer):
+  """
+  Base Layer implementation for twml.
+  Overloads ``tf.layers.Layer``
+  from tensorflow and adds a couple of custom methods.
+  """
+
+  @property
+  def init(self):
+    """
+    Return initializer ops. By default returns tf.no_op().
+    This method is overwritten by classes like twml.layers.MDL, which
+    uses a HashTable internally, that must be initialized with its own op.
+    """
+    return tf.no_op()
+
+  def call(self, inputs, **kwargs):
+    """The logic of the layer lives here.
+
+    Arguments:
+      inputs:
+        input tensor(s).
+      **kwargs:
+        additional keyword arguments.
+
+    Returns:
+      Output tensor(s).
+    """
+    raise NotImplementedError
+
+  def compute_output_shape(self, input_shape):
+    """Computes the output shape of the layer given the input shape.
+
+    Args:
+      input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+        be fully defined (e.g. the batch size may be unknown).
+
+    Raises NotImplementedError.
+
+    """
+    raise NotImplementedError
diff --git a/twml/twml/layers/mdl.py b/twml/twml/layers/mdl.py
new file mode 100644
index 000000000..cf4018afa
--- /dev/null
+++ b/twml/twml/layers/mdl.py
@@ -0,0 +1,256 @@
+# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
+"""
+Implementing MDL Layer
+"""
+
+
+from .layer import Layer
+from .partition import Partition
+from .stitch import Stitch
+
+import libtwml
+import numpy as np
+import tensorflow.compat.v1 as tf
+import twml
+
+
+class MDL(Layer):  # noqa: T000
+  """
+  MDL layer is constructed by MDLCalibrator after accumulating data
+  and performing minimum description length (MDL) calibration.
+
+  MDL takes sparse continuous features and converts them to sparse
+  binary features. Each binary output feature is associated with an MDL bin.
+  Each MDL input feature is converted to n_bin bins.
+  Each MDL calibration tries to find bin delimiters such that the number of feature values
+  per bin is roughly equal (for each given MDL feature).
+  Note that if an input feature is rarely used, so will its associated output bins/features be.
+  """
+
+  def __init__(
+      self,
+      n_feature, n_bin, out_bits,
+      bin_values=None, hash_keys=None, hash_values=None,
+      bin_ids=None, feature_offsets=None, **kwargs):
+    """
+    Creates a non-initialized `MDL` object.
+    Before using the table you will have to initialize it. After initialization
+    the table will be immutable.
+
+    Parent class args:
+      see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
+      for documentation of parent class arguments.
+
+    Required args:
+      n_feature:
+        number of unique features accumulated during MDL calibration.
+        This is the number of features in the hash map.
+ Used to initialize bin_values, hash_keys, hash_values, + bin_ids, bin_values and feature_offsets. + n_bin: + number of MDL bins used for MDL calibration. + Used to initialize bin_values, hash_keys, hash_values, + bin_ids, bin_values and feature_offsets. + out_bits: + Determines the maximum value for output feature IDs. + The dense_shape of the SparseTensor returned by lookup(x) + will be [x.shape[0], 1 << output_bits]. + + Optional args: + hash_keys: + contains the features ID that MDL discretizes and knows about. + The hash map (hash_keys->hash_values) is used for two reasons: + 1. divide inputs into two feature spaces: MDL vs non-MDL + 2. transate the MDL features into a hash_feature ID that MDL understands. + The hash_map is expected to contain n_feature items. + hash_values: + translates the feature IDs into hash_feature IDs for MDL. + bin_ids: + a 1D Tensor of size n_feature * n_bin + 1 which contains + unique IDs to which the MDL features will be translated to. + For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce + the most efficient output space. + bin_values: + a 1D Tensor aligned with bin_ids. + For a given hash_feature ID j, it's value bin's are indexed between + `j*n_bin` and `j*n_bin + n_bin-1`. + As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j + and a inputs value between + `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`. + feature_offsets: + a 1D Tensor specifying the starting location of bins for a given feature id. + For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')). + """ + super(MDL, self).__init__(**kwargs) + tf.logging.warning("MDL will be deprecated. Please use PercentileDiscretizer instead") + + max_mdl_feature = n_feature * (n_bin + 1) + self._n_feature = n_feature + self._n_bin = n_bin + + self._hash_keys_initializer = tf.constant_initializer( + hash_keys if hash_keys is not None + else np.empty(n_feature, dtype=np.int64), + dtype=np.int64 + ) + self._hash_values_initializer = tf.constant_initializer( + hash_values if hash_values is not None + else np.empty(n_feature, dtype=np.int64), + dtype=np.int64 + ) + self._bin_ids_initializer = tf.constant_initializer( + bin_ids if bin_ids is not None + else np.empty(max_mdl_feature, dtype=np.int64), + dtype=np.int64 + ) + self._bin_values_initializer = tf.constant_initializer( + bin_values if bin_values is not None + else np.empty(max_mdl_feature, dtype=np.float32), + dtype=np.float32 + ) + self._feature_offsets_initializer = tf.constant_initializer( + feature_offsets if feature_offsets is not None + else np.empty(n_feature, dtype=np.int64), + dtype=np.int64 + ) + + # note that calling build here is an exception as typically __call__ would call build(). + # We call it here because we need to initialize hash_map. + # Also note that the variable_scope is set by add_variable in build() + if not self.built: + self.build(input_shape=None) + + self.output_size = tf.convert_to_tensor(1 << out_bits, tf.int64) + + def build(self, input_shape): # pylint: disable=unused-argument + """ + Creates the variables of the layer: + hash_keys, hash_values, bin_ids, bin_values, feature_offsets and self.output_size. 
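+
+ Note that build() is called from __init__ here, rather than lazily on the
+ first __call__, because the internal hash map must exist before the layer
+ is used (see the comment in __init__ above).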
+ """ + + # build layers + self.partition = Partition() + self.stitch = Stitch() + + # build variables + + hash_keys = self.add_variable( + 'hash_keys', + initializer=self._hash_keys_initializer, + shape=[self._n_feature], + dtype=tf.int64, + trainable=False) + + hash_values = self.add_variable( + 'hash_values', + initializer=self._hash_values_initializer, + shape=[self._n_feature], + dtype=tf.int64, + trainable=False) + + # hashmap converts known features into range [0, n_feature) + initializer = tf.lookup.KeyValueTensorInitializer(hash_keys, hash_values) + self.hash_map = tf.lookup.StaticHashTable(initializer, -1) + + self.bin_ids = self.add_variable( + 'bin_ids', + initializer=self._bin_ids_initializer, + shape=[self._n_feature * (self._n_bin + 1)], + dtype=tf.int64, + trainable=False) + + self.bin_values = self.add_variable( + 'bin_values', + initializer=self._bin_values_initializer, + shape=[self._n_feature * (self._n_bin + 1)], + dtype=tf.float32, + trainable=False) + + self.feature_offsets = self.add_variable( + 'feature_offsets', + initializer=self._feature_offsets_initializer, + shape=[self._n_feature], + dtype=tf.int64, + trainable=False) + + # make sure this is last + self.built = True + + def call(self, inputs, **kwargs): + """Looks up `keys` in a table, outputs the corresponding values. + + Implements MDL inference where inputs are intersected with a hash_map. + Part of the inputs are discretized using twml.mdl to produce a mdl_output SparseTensor. + This SparseTensor is then joined with the original inputs SparseTensor, + but only for the inputs keys that did not get discretized. + + Args: + inputs: A 2D SparseTensor that is input to MDL for discretization. + It has a dense_shape of [batch_size, input_size] + name: A name for the operation (optional). + Returns: + A `SparseTensor` of the same type as `inputs`. + Its dense_shape is [shape_input.dense_shape[0], 1 << output_bits]. 
+ """ + if isinstance(inputs, tf.SparseTensor): + inputs = twml.SparseTensor.from_tf(inputs) + + assert(isinstance(inputs, twml.SparseTensor)) + + # sparse column indices + ids = inputs.ids + # sparse row indices + keys = inputs.indices + # sparse values + vals = inputs.values + + # get intersect(keys, hash_map) + hashed_keys = self.hash_map.lookup(keys) + + found = tf.not_equal(hashed_keys, tf.constant(-1, tf.int64)) + partition_ids = tf.cast(found, tf.int32) + + vals, key, indices = self.partition(partition_ids, vals, tf.where(found, hashed_keys, keys)) + non_mdl_keys, mdl_in_keys = key + non_mdl_vals, mdl_in_vals = vals + + self.non_mdl_keys = non_mdl_keys + + # run MDL on the keys/values it knows about + mdl_keys, mdl_vals = libtwml.ops.mdl(mdl_in_keys, mdl_in_vals, self.bin_ids, self.bin_values, + self.feature_offsets) + + # handle output ID conflicts + mdl_size = tf.size(self.bin_ids, out_type=tf.int64) + non_mdl_size = tf.subtract(self.output_size, mdl_size) + non_mdl_keys = tf.add(tf.floormod(non_mdl_keys, non_mdl_size), mdl_size) + + # Stitch the keys and values from mdl and non mdl indices back, with help + # of the Stitch Layer + + # out for inference checking + self.mdl_out_keys = mdl_keys + + concat_data = self.stitch([non_mdl_vals, mdl_vals], + [non_mdl_keys, mdl_keys], + indices) + + concat_vals, concat_keys = concat_data + + # Generate output shape using _compute_output_shape + + batch_size = tf.to_int64(inputs.dense_shape[0]) + output_shape = [batch_size, self.output_size] + return twml.SparseTensor(ids, concat_keys, concat_vals, output_shape).to_tf() + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. the batch size may be unknown). + + Raises NotImplementedError. + + """ + raise NotImplementedError diff --git a/twml/twml/layers/partition.py b/twml/twml/layers/partition.py new file mode 100644 index 000000000..0e7c85f18 --- /dev/null +++ b/twml/twml/layers/partition.py @@ -0,0 +1,74 @@ +""" +Implementing partition Layer +""" + + +from .layer import Layer + +import tensorflow.compat.v1 as tf + + +class Partition(Layer): + """ + This layer implements: + + .. code-block:: python + + tf.dynamic_partition(input_vals, partition_ids, self.partitions) + + Input: + partitions: + the number of partitions which we will divide the hashmap keys/bvalues + + Output: + A layer that performs partitioning + """ + + def __init__(self, partitions=2, **kwargs): + self.partitions = partitions + super(Partition, self).__init__(**kwargs) + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. the batch size may be unknown). + + Raises NotImplementedError. + + """ + raise NotImplementedError + + def call(self, partition_ids, input_vals, input_keys, **kwargs): + """This layer is responsible for partitioning the values/keys of a hashmap + + Arguments: + partition_ids: + Tensor that is equivalent to boolean (int32). + input_vals: + Tensor that represents the values of the hashmap(float). + input_keys: + Tensor that represents the keys of the hashmap(float) + + Returns: + The output of the partition layer, which is a list of lists which looks + something like: + + .. 
+ .. code-block:: python
+
+ [[vals_0, vals_1], [keys_0, keys_1], [indices_0, indices_1]]
+
+ where:
+ vals_x:
+ values of the hashmap for partition x
+ keys_x:
+ keys of the hashmap for partition x
+ indices_x:
+ indices of the hashmap for partition x
+ """
+ partioned_val = tf.dynamic_partition(input_vals, partition_ids, self.partitions)
+ partioned_keys = tf.dynamic_partition(input_keys, partition_ids, self.partitions)
+ partioned_indices = tf.dynamic_partition(tf.range(tf.shape(partition_ids)[0]),
+ tf.cast(partition_ids, tf.int32), self.partitions)
+ return [partioned_val, partioned_keys, partioned_indices]
diff --git a/twml/twml/layers/percentile_discretizer.py b/twml/twml/layers/percentile_discretizer.py
new file mode 100644
index 000000000..55bb4de8c
--- /dev/null
+++ b/twml/twml/layers/percentile_discretizer.py
@@ -0,0 +1,209 @@
+# pylint: disable=no-member, attribute-defined-outside-init, too-many-instance-attributes
+"""
+Implementing PercentileDiscretizer Layer
+"""
+
+
+import libtwml
+import numpy as np
+import tensorflow.compat.v1 as tf
+import twml
+from twml.layers import Layer
+
+
+class PercentileDiscretizer(Layer):
+ """
+ PercentileDiscretizer layer is constructed by PercentileDiscretizerCalibrator after
+ accumulating data and performing percentile bucket calibration.
+
+ PercentileDiscretizer takes sparse continuous features and converts them to sparse
+ binary features. Each binary output feature is associated with a PercentileDiscretizer bin.
+ Each PercentileDiscretizer input feature is converted to n_bin bins.
+ Each PercentileDiscretizer calibration tries to find bin delimiters such
+ that the number of feature values per bin is roughly equal (for
+ each given PercentileDiscretizer feature). In other words, bins are calibrated to be approx.
+ equiprobable, according to the given calibration data.
+ Note that if an input feature is rarely used, its associated output bins/features
+ will be rarely used as well.
+ """
+
+ def __init__(
+ self,
+ n_feature, n_bin, out_bits,
+ bin_values=None, hash_keys=None, hash_values=None,
+ bin_ids=None, feature_offsets=None, num_parts=1, cost_per_unit=100, **kwargs):
+ """
+ Creates a non-initialized `PercentileDiscretizer` object.
+ Before using the table you will have to initialize it. After initialization
+ the table will be immutable.
+
+ If there are no calibrated features, then the discretizer will only apply
+ twml.util.limit_bits to the feature keys (aka "feature_ids"). Essentially,
+ the discretizer will be a "no-operation", other than obeying `out_bits`.
+
+ Parent class args:
+ see [tf.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/layers/Layer)
+ for documentation of parent class arguments.
+
+ Required args:
+ n_feature:
+ number of unique features accumulated during PercentileDiscretizer calibration.
+ This is the number of features in the hash map.
+ Used to initialize bin_values, hash_keys, hash_values,
+ bin_ids, bin_values and feature_offsets.
+ n_bin:
+ number of PercentileDiscretizer bins used for PercentileDiscretizer calibration.
+ Used to initialize bin_values, hash_keys, hash_values,
+ bin_ids, bin_values and feature_offsets.
+ out_bits:
+ Determines the maximum value for output feature IDs.
+ The dense_shape of the SparseTensor returned by lookup(x)
+ will be [x.shape[0], 1 << out_bits].
+
+ Optional args:
+ hash_keys:
+ contains the feature IDs that PercentileDiscretizer discretizes and knows about.
+ The hash map (hash_keys->hash_values) is used for two reasons:
+ 1. divide inputs into two feature spaces:
+ PercentileDiscretizer vs non-PercentileDiscretizer
+ 2. translate the PercentileDiscretizer features into a hash_feature ID that
+ PercentileDiscretizer understands.
+ The hash_map is expected to contain n_feature items.
+ hash_values:
+ translates the feature IDs into hash_feature IDs for PercentileDiscretizer.
+ bin_ids:
+ a 1D Tensor of size n_feature * (n_bin + 1) which contains
+ unique IDs to which the PercentileDiscretizer features will be translated.
+ For example, tf.Tensor(np.arange(n_feature * n_bin)) would produce
+ the most efficient output space.
+ bin_values:
+ a 1D Tensor aligned with bin_ids.
+ For a given hash_feature ID j, its value bins are indexed between
+ `j*n_bin` and `j*n_bin + n_bin-1`.
+ As such, bin_ids[j*n_bin+i] is translated from a hash_feature ID of j
+ and an input value between
+ `bin_values[j*n_bin + i]` and `bin_values[j*n_bin+i+1]`.
+ feature_offsets:
+ a 1D Tensor specifying the starting location of bins for a given feature id.
+ For example, tf.Tensor(np.arange(0, bin_values.size, n_bin, dtype='int64')).
+ """
+
+ super(PercentileDiscretizer, self).__init__(**kwargs)
+
+ if not self.built:
+ self.build(input_shape=None)
+
+ max_discretizer_feature = n_feature * (n_bin + 1)
+ self._n_feature = n_feature
+ self._n_bin = n_bin
+
+ # build variables
+ self._out_bits = out_bits
+ self._output_size = tf.convert_to_tensor(1 << out_bits, tf.int64)
+ self._hash_keys = (hash_keys if hash_keys is not None else
+ np.empty(n_feature, dtype=np.int64))
+ self._hash_values = (hash_values if hash_values is not None else
+ np.empty(n_feature, dtype=np.int64))
+ self._bin_ids = (bin_ids if bin_ids is not None else
+ np.empty(max_discretizer_feature, dtype=np.int64))
+ self._bin_values = (bin_values if bin_values is not None else
+ np.empty(max_discretizer_feature, dtype=np.float32))
+ self._feature_offsets = (feature_offsets if feature_offsets is not None else
+ np.empty(n_feature, dtype=np.int64))
+ self.num_parts = num_parts
+ self.cost_per_unit = cost_per_unit
+
+ def build(self, input_shape): # pylint: disable=unused-argument
+ """
+ Creates the variables of the layer
+ """
+ self.built = True
+
+ def call(self, inputs, keep_inputs=False, **kwargs):
+ """Looks up `keys` in a table, outputs the corresponding values.
+
+ Implements PercentileDiscretizer inference, where inputs are intersected with a hash_map.
+ Input features that were not calibrated have their feature IDs truncated, so as
+ to be less than 1 << out_bits, but their values are not discretized.
+
+ Args:
+ inputs: A 2D SparseTensor that is input to PercentileDiscretizer for discretization.
+ It has a dense_shape of [batch_size, input_size].
+ keep_inputs:
+ Include the original inputs in the output; the original (undiscretized)
+ features are remapped so that their IDs do not collide with the
+ discretizer's output IDs.
+ Returns:
+ A `SparseTensor` of the same type as `inputs`.
+ Its dense_shape is [shape_input.dense_shape[0], 1 << out_bits].
+ """
+ if isinstance(inputs, tf.SparseTensor):
+ inputs = twml.SparseTensor.from_tf(inputs)
+
+ assert isinstance(inputs, twml.SparseTensor)
+
+ # sparse row ids
+ ids = inputs.ids
+ # sparse column indices (feature keys)
+ keys = inputs.indices
+ # sparse values
+ vals = inputs.values
+
+ if self._n_feature > 0:
+ discretizer_keys, discretizer_vals = libtwml.ops.percentile_discretizer_v2(
+ input_ids=keys, # inc key assigned to feature_id, or -1
+ input_vals=vals, # the observed feature values
+ bin_ids=self._bin_ids, # n_feat X (n_bin+1) 2D arange
+ bin_vals=self._bin_values, # bin boundaries
+ feature_offsets=self._feature_offsets, # 0 : nbin_1 : max_feat
+ output_bits=self._out_bits,
+ feature_ids=tf.make_tensor_proto(self._hash_keys), # feature ids to build internal hash map
+ feature_indices=tf.make_tensor_proto(self._hash_values), # keys associated w/ feat. indices
+ start_compute=tf.constant(0, shape=[], dtype=tf.int64),
+ end_compute=tf.constant(-1, shape=[], dtype=tf.int64),
+ cost_per_unit=self.cost_per_unit
+ )
+ else:
+ discretizer_keys = twml.util.limit_bits(keys, self._out_bits)
+ discretizer_vals = vals
+ # don't 2x the input.
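+ # (in this branch the inputs already pass through untouched, so keeping
+ # them as well would emit every feature twice)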
+ keep_inputs = False + + batch_size = tf.to_int64(inputs.dense_shape[0]) + output_shape = [batch_size, self._output_size] + + output = twml.SparseTensor(ids, discretizer_keys, discretizer_vals, output_shape).to_tf() + + if keep_inputs: + # Note the non-discretized features will end up doubled, + # since these are already in `output` + # handle output ID conflicts + mdl_size = self._n_feature * (self._n_bin + 1) + non_mdl_size = tf.subtract(self._output_size, mdl_size) + input_keys = tf.add(tf.floormod(keys, non_mdl_size), mdl_size) + + new_input = twml.SparseTensor( + ids=ids, indices=input_keys, values=vals, dense_shape=output_shape).to_tf() + + # concatenate discretizer output with original input + sparse_add = tf.sparse_add(new_input, output) + output = tf.SparseTensor(sparse_add.indices, sparse_add.values, output_shape) + + return output + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. the batch size may be unknown). + + Raises NotImplementedError. + + """ + raise NotImplementedError diff --git a/twml/twml/layers/sequential.py b/twml/twml/layers/sequential.py new file mode 100644 index 000000000..c0d4b92cc --- /dev/null +++ b/twml/twml/layers/sequential.py @@ -0,0 +1,160 @@ +""" +Implementing Sequential Layer container +""" + + +from .layer import Layer + +from tensorflow import keras +from tensorflow.python.layers import base + + +class Sequential(Layer): + """ + A sequential stack of layers. + + Arguments: + layers: list of layers to add to the model. + + Output: + the output of the sequential layers + """ + + def __init__(self, layers=None, **kwargs): + self._layers = [] # Stack of layers. + self._layer_names = [] # Stack of layers names + self._layer_outputs = [] + # Add to the model any layers passed to the constructor. + if layers: + for layer in layers: + self.add(layer) + super(Sequential, self).__init__(**kwargs) + + def add(self, layer): + """Adds a layer instance on top of the layer stack. + + Arguments: + layer: + layer instance. + + Raises: + TypeError: + if the layer argument is not instance of base.Layer + """ + if not isinstance(layer, base.Layer) and not isinstance(layer, keras.layers.Layer): + raise TypeError('The added layer must be an instance of class Layer') + + if layer.name in self._layer_names: + raise ValueError('Layer with name %s already exists in sequential layer' % layer.name) + + self._layers.append(layer) + self._layer_names.append(layer.name) + + def pop(self): + """Removes the last layer in the model. + + Raises: + TypeError: + if there are no layers in the model. + """ + if not self._layers or not self._layer_names: + raise TypeError('There are no layers in the model.') + self._layers.pop() + self._layer_names.pop() + + def call(self, inputs, **kwargs): # pylint: disable=unused-argument + """The logic of the layer lives here. + + Arguments: + inputs: + input tensor(s). 
+
+ Returns:
+ The output of the sequential layers
+ """
+ self._layer_outputs = []
+ for layer in self._layers:
+ # don't use layer.call because you want to build individual layers
+ inputs = layer(inputs) # overwrites the current input after it has been processed
+ self._layer_outputs.append(inputs)
+ return inputs
+
+ @property
+ def layers(self):
+ """ Return the layers in the sequential layer """
+ return self._layers
+
+ @property
+ def layer_names(self):
+ """ Return the layer names in the sequential layer """
+ return self._layer_names
+
+ @property
+ def layer_outputs(self):
+ """ Return the layer outputs in the sequential layer """
+ return self._layer_outputs
+
+ def get(self, key):
+ """Retrieves the n-th layer.
+
+ Arguments:
+ key:
+ index of the layer
+
+ Output:
+ The n-th layer, where n is equal to the key.
+ """
+ return self._layers[key]
+
+ def get_output(self, key):
+ """Retrieves the n-th layer output.
+
+ Arguments:
+ key:
+ index of the layer
+
+ Output:
+ The intermediary output equivalent to the n-th layer, where n is equal to the key.
+ """
+ return self._layer_outputs[key]
+
+ def get_layer_by_name(self, name):
+ """Retrieves the layer corresponding to the name.
+
+ Arguments:
+ name:
+ name of the layer
+
+ Output:
+ the layer with the desired name
+ """
+ return self._layers[self._layer_names.index(name)]
+
+ def get_layer_output_by_name(self, name):
+ """Retrieves the layer output corresponding to the name.
+
+ Arguments:
+ name:
+ name of the layer
+
+ Output:
+ the output of the layer with the desired name
+ """
+ return self._layer_outputs[self._layer_names.index(name)]
+
+ @property
+ def init(self):
+ """ returns a list of initialization ops (one per layer) """
+ return [layer.init for layer in self._layers]
+
+ def compute_output_shape(self, input_shape):
+ """Computes the output shape of the layer given the input shape.
+
+ Args:
+ input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+ be fully defined (e.g. the batch size may be unknown).
+
+ Raises NotImplementedError.
+
+ """
+ raise NotImplementedError
diff --git a/twml/twml/layers/sparse_max_norm.py b/twml/twml/layers/sparse_max_norm.py
new file mode 100644
index 000000000..e1f423fe0
--- /dev/null
+++ b/twml/twml/layers/sparse_max_norm.py
@@ -0,0 +1,221 @@
+# pylint: disable=no-member, attribute-defined-outside-init, duplicate-code
+"""
+Contains the twml.layers.SparseMaxNorm layer.
+"""
+from .layer import Layer
+
+from libtwml import OPLIB
+import tensorflow.compat.v1 as tf
+import twml
+
+
+class SparseMaxNorm(Layer):
+ """
+ Computes a max-normalization and adds bias to the sparse_input,
+ forwards that through a sparse affine transform followed
+ by a non-linear activation on the resulting dense representation.
+
+ This layer has two parameters, one of which learns through gradient descent:
+ bias_x (optional):
+ vector of shape [input_size]. Learned through gradient descent.
+ max_x:
+ vector of shape [input_size]. Holds the maxima of input ``x`` for normalization.
+ Either calibrated through the SparseMaxNorm calibrator, or calibrated online, or both.
+
+ The pseudo-code for this layer looks like:
+
+ .. code-block:: python
+
+ abs_x = abs(x)
+ normed_x = clip_by_value(x / max_x, -1, 1)
+ biased_x = normed_x + bias_x
+ return biased_x
+
+
+ Args:
+ max_x_initializer:
+ initializer vector of shape [input_size] used by variable `max_x`
+ bias_x_initializer:
+ initializer vector of shape [input_size] used by parameter `bias_x`
+ is_training:
+ Are we training the layer to learn the normalization maxima?
+ If set to True, max_x will be able to learn. This is independent of bias_x.
+ epsilon:
+ The minimum value used for max_x. Defaults to 1E-5.
+ use_bias:
+ Default True. Set to False to not use a bias term.
+
+ Returns:
+ A layer representing the output of the sparse_max_norm transformation.
+ """
+
+ def __init__(
+ self,
+ input_size=None,
+ max_x_initializer=None,
+ bias_x_initializer=None,
+ is_training=True,
+ epsilon=1E-5,
+ use_bias=True,
+ **kwargs):
+
+ super(SparseMaxNorm, self).__init__(**kwargs)
+ if input_size:
+ raise ValueError('input_size is deprecated - it is now '
+ 'automatically inferred from your input.')
+ if max_x_initializer is None:
+ max_x_initializer = tf.zeros_initializer()
+ self.max_x_initializer = max_x_initializer
+
+ self._use_bias = use_bias
+ if use_bias:
+ if bias_x_initializer is None:
+ bias_x_initializer = tf.zeros_initializer()
+ self.bias_x_initializer = bias_x_initializer
+
+ self.epsilon = epsilon
+ self.is_training = is_training
+
+ def build(self, input_shape): # pylint: disable=unused-argument
+ """Creates the max_x and bias_x tf.Variables of the layer."""
+
+ self.max_x = self.add_variable(
+ 'max_x',
+ initializer=self.max_x_initializer,
+ shape=[input_shape[1]],
+ dtype=tf.float32,
+ trainable=False)
+
+ if self._use_bias:
+ self.bias_x = self.add_variable(
+ 'bias_x',
+ initializer=self.bias_x_initializer,
+ shape=[input_shape[1]],
+ dtype=tf.float32,
+ trainable=True)
+
+ self.built = True
+
+ def compute_output_shape(self, input_shape):
+ """Computes the output shape of the layer given the input shape.
+
+ Args:
+ input_shape: A (possibly nested tuple of) `TensorShape`. It need not
+ be fully defined (e.g. the batch size may be unknown).
+
+ Raises NotImplementedError.
+
+ """
+ raise NotImplementedError
+
+ def _call(self, inputs, **kwargs): # pylint: disable=unused-argument
+ """
+ The forward propagation logic of the layer lives here.
+
+ Arguments:
+ sparse_input:
+ A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]``
+ Returns:
+ A ``tf.SparseTensor`` representing the output of the max_norm transformation; this can
+ be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``.
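+
+ A rough usage sketch (tensor names are hypothetical):
+
+ .. code-block:: python
+
+ max_norm = twml.layers.SparseMaxNorm(is_training=True)
+ normed = max_norm(sparse_input) # tf.SparseTensor, same dense_shape
+ logits = twml.layers.full_sparse(normed, output_size=1)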
+ """ + + if isinstance(inputs, twml.SparseTensor): + inputs = inputs.to_tf() + elif not isinstance(inputs, tf.SparseTensor): + raise TypeError("The inputs must be of type tf.SparseTensor or twml.SparseTensor") + + indices_x = inputs.indices[:, 1] + values_x = inputs.values + + if self.is_training is False: + normalized_x = OPLIB.sparse_max_norm_inference(self.max_x, + indices_x, + values_x, + self.epsilon) + + update_op = tf.no_op() + else: + max_x, normalized_x = OPLIB.sparse_max_norm_training(self.max_x, + indices_x, + values_x, + self.epsilon) + + update_op = tf.assign(self.max_x, max_x) + + with tf.control_dependencies([update_op]): + normalized_x = tf.stop_gradient(normalized_x) + + # add input bias + if self._use_bias: + normalized_x = normalized_x + tf.gather(self.bias_x, indices_x) + + # convert back to sparse tensor + return tf.SparseTensor(inputs.indices, normalized_x, inputs.dense_shape) + + def call(self, inputs, **kwargs): # pylint: disable=unused-argument + """ + The forward propagation logic of the layer lives here. + + Arguments: + sparse_input: + A 2D ``tf.SparseTensor`` of dense_shape ``[batch_size, input_size]`` + Returns: + A ``tf.SparseTensor`` representing the output of the max_norm transformation, this can + be fed into twml.layers.FullSparse in order to be transformed into a ``tf.Tensor``. + """ + with tf.device(self.max_x.device): + return self._call(inputs, **kwargs) + +# For backwards compatiblity and also because I don't want to change all the tests. +MaxNorm = SparseMaxNorm + + +def sparse_max_norm(inputs, + input_size=None, + max_x_initializer=None, + bias_x_initializer=None, + is_training=True, + epsilon=1E-5, + use_bias=True, + name=None, + reuse=None): + """ + Functional inteface to SparseMaxNorm. + + Args: + inputs: + A sparse tensor (can be twml.SparseTensor or tf.SparseTensor) + input_size: + number of input units + max_x_initializer: + initializer vector of shape [input_size] used by variable `max_x` + bias_x_initializer: + initializer vector of shape [input_size] used by parameter `bias_x` + is_training: + Are we training the layer to learn the normalization maximas. + If set to True, max_x will be able to learn. This is independent of bias_x + epsilon: + The minimum value used for max_x. Defaults to 1E-5. + use_bias: + Default True. Set to False to not use a bias term. + + Returns: + Output after normalizing with the max value. + """ + if input_size: + raise ValueError('input_size is deprecated - it is now automatically \ + inferred from your input.') + + if isinstance(inputs, twml.SparseTensor): + inputs = inputs.to_tf() + + layer = SparseMaxNorm(max_x_initializer=max_x_initializer, + bias_x_initializer=bias_x_initializer, + is_training=is_training, + epsilon=epsilon, + use_bias=use_bias, + name=name, + _scope=name, + _reuse=reuse) + return layer(inputs) diff --git a/twml/twml/layers/stitch.py b/twml/twml/layers/stitch.py new file mode 100644 index 000000000..51dffdb8e --- /dev/null +++ b/twml/twml/layers/stitch.py @@ -0,0 +1,54 @@ +# pylint: disable=useless-super-delegation +""" +Implementing Stitch Layer +""" + + +from .layer import Layer + +import tensorflow.compat.v1 as tf + + +class Stitch(Layer): + """ + This layer is responsible for stitching a partioned layer together. + + Output: + A layer that performs stitching + """ + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer given the input shape. + + Args: + input_shape: A (possibly nested tuple of) `TensorShape`. It need not + be fully defined (e.g. 
the batch size may be unknown). + + Raises NotImplementedError. + + """ + raise NotImplementedError + + def call(self, partioned_val, partioned_keys, + partioned_indices, **kwargs): # pylint: disable=unused-argument, arguments-differ + """ + This layer is responsible for stitching a partioned layer together. + + Input: + partioned_val: + a list of partioned Tensors which represent the vals of the hashmap + partioned_keys: + a list of partioned Tensors which represent the keys of the hashmap + partioned_indices: + a list of partioned Tensors which represent the indices of the hashmap + Output: + List which contains: [output_vals, output_keys] + output_vals: + Values of the HashMap (float) + output_keys: + Keys of HashMap (float) + """ + indices = [tf.to_int32(index) for index in partioned_indices] + concat_keys = tf.dynamic_stitch(indices, partioned_keys) + concat_vals = tf.dynamic_stitch(indices, partioned_val) + return [concat_vals, concat_keys] diff --git a/twml/twml/learning_rate_decay.py b/twml/twml/learning_rate_decay.py new file mode 100644 index 000000000..be522d75b --- /dev/null +++ b/twml/twml/learning_rate_decay.py @@ -0,0 +1,168 @@ +# pylint: disable=too-many-branches +""" This module includes functions for managing learning rate decay """ +import tensorflow.compat.v1 as tf + + +def get_learning_rate_decay_fn(params): + """ + Returns a learning rate decay function that takes the initial + learning_rate and global_step + as arguments and returns the current learning rate. + + Currently supports params.learning_rate_decay values of: + exponential | polynomial | piecewise_constant | cosine | cosine restarts. + See `Decaying the Leanring Rate + `_ for details. + + Arguments: + params: + a tensorflow.contrib.train.HParams object containing the relevant hyperparameters. 
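+
+ For example, exponential decay could be configured roughly as follows
+ (a sketch; only the keys that are checked below need to be present):
+
+ .. code-block:: python
+
+ params = tf.contrib.training.HParams(
+ learning_rate_decay='exponential_learning_rate_decay',
+ decay_steps=10000,
+ exponential_decay_rate=0.96)
+ decay_fn = get_learning_rate_decay_fn(params)
+ # decay_fn(learning_rate, global_step) -> decayed learning rate tensor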
+ """ + paramsv = params.values() + if 'learning_rate_decay' not in paramsv or params.learning_rate_decay == 'no_learning_rate_decay': + return None + elif params.learning_rate_decay == 'exponential_learning_rate_decay': + if 'decay_steps' not in paramsv: + raise ValueError("Expecting params.decay_steps for " + "params.learning_rate_decay == 'exponential'") + if 'exponential_decay_rate' not in paramsv: + raise ValueError("Expecting params.exponential_decay_rate for " + "params.learning_rate_decay == 'exponential'") + + def exponential_decay_fn(learning_rate, global_step): + """ exponential decay function to be passed to optimize_loss """ + return tf.train.exponential_decay( + learning_rate=learning_rate, + global_step=global_step, + decay_steps=params.decay_steps, + decay_rate=params.exponential_decay_rate + ) + return exponential_decay_fn + elif params.learning_rate_decay == 'piecewise_constant_learning_rate_decay': + if 'piecewise_constant_boundaries' not in paramsv: + raise ValueError("Expecting params.piecewise_constant_boundaries for " + "params.learning_rate_decay == 'piecewise_constant'") + if 'piecewise_constant_values' not in paramsv: + raise ValueError("Expecting params.piecewise_constant_values for " + "params.learning_rate_decay == 'piecewise_constant'") + # pylint: disable=unused-argument + + def piecewise_constant_fn(learning_rate, global_step): + """ piecewise_constant decay function to be passed to optimize_loss """ + return tf.train.piecewise_constant( + x=global_step, + boundaries=params.piecewise_constant_boundaries, + values=params.piecewise_constant_values + ) + return piecewise_constant_fn + elif params.learning_rate_decay == 'polynomial_learning_rate_decay': + if 'decay_steps' not in paramsv: + raise ValueError("Expecting params.decay_steps for " + "params.learning_rate_decay == 'polynomial'") + if 'end_learning_rate' not in paramsv: + raise ValueError("Expecting params.end_learning_rate for " + "params.learning_rate_decay == 'polynomial'") + + def polynomial_decay_fn(learning_rate, global_step): + """ polynomial decay function to be passed to optimize_loss """ + return tf.train.polynomial_decay( + learning_rate=learning_rate, + global_step=global_step, + decay_steps=params.decay_steps, + end_learning_rate=params.end_learning_rate, + power=params.polynomial_power if 'polynomial_power' in paramsv else 1.0, + ) + return polynomial_decay_fn + + elif params.learning_rate_decay == 'inverse_learning_rate_decay': + if 'min_learning_rate' not in paramsv: + raise ValueError("Expecting params.min_learning_rate for " + "params.learning_rate_decay == 'inverse'") + if 'decay_rate' not in paramsv: + raise ValueError("Expecting params.decay_rate for " + "params.learning_rate_decay == 'inverse'") + if 'decay_steps' not in paramsv: + raise ValueError("Expecting params.decay_steps for " + "params.learning_rate_decay == 'inverse'") + + def bounded_inverse_time_decay_fn(learning_rate, global_step): + ''' + Returns the decayed learning_rate by applying the function: + decayed_lr = max(lr /(1 + decay_rate * floor(global_step /decay_step)), + min_learning_rate) + Arguments: + learning_rate: + A scalar `float32` or `float64` `Tensor` or a Python number. + The initial learning rate. + global_step: + A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. Must not be negative. + min_learning_rate: + A scalar `int32` or `int64` `Tensor` or a Python number. + Minimum possible learning_rate. 
The decayed learning_rate will not be + smaller than the min_learning_rate + decay_steps: + How often to apply decay. In dbv1, this should be 1. + decay_rate: + A scalar `int32` or `int64` `Tensor` or a Python number. + Rate in which we decay the learning rate. + Returns: + A scalar `Tensor` of the same type as `learning_rate`. The decayed + learning rate. + ''' + decayed_rate = tf.train.inverse_time_decay( + learning_rate=learning_rate, + global_step=global_step, + decay_steps=params.decay_steps, + decay_rate=params.decay_rate) + # Getting dtype of returned Tensor + dtype = decayed_rate.dtype + # Casting the min_learning rate the same dtype as decayes rate + min_learning_rate = tf.cast(params.min_learning_rate, dtype) + # Returning the maximum between the two + return tf.maximum(decayed_rate, min_learning_rate) + + return bounded_inverse_time_decay_fn + + elif params.learning_rate_decay == 'cosine_learning_rate_decay': + if 'decay_steps' not in paramsv: + raise ValueError("Expecting params.decay_steps for " + "params.learning_rate_decay == 'cosine_decay'") + if "alpha" not in paramsv: + raise ValueError("Expecting params.alpha for " + "params.learning_rate_decay == 'cosine_decay'") + def cosine_decay_fn(learning_rate, global_step): + """ cosine decay function to be passed to optimize_loss """ + return tf.train.cosine_decay( + learning_rate=learning_rate, + global_step=global_step, + decay_steps=params.decay_steps, + alpha=params.alpha + ) + return cosine_decay_fn + elif params.learning_rate_decay == 'cosine_restarts_learning_rate_decay': + if 'first_decay_steps' not in paramsv: + raise ValueError("Expecting params.first_decay_steps for " + "params.learning_rate_decay == 'cosine_restarts_decay'") + if 't_mul' not in paramsv: + raise ValueError("Expecting params.t_mul for " + "params.learning_rate_decay == 'cosine_restarts_decay'") + if 'm_mul' not in paramsv: + raise ValueError("Expecting params.m_mul for " + "params.learning_rate_decay == 'cosine_restarts_decay'") + if "alpha" not in paramsv: + raise ValueError("Expecting params.alpha for " + "params.learning_rate_decay == 'cosine_restarts_decay'") + def cosine_restart_decay_fn(learning_rate, global_step): + """ cosine decay function to be passed to optimize_loss """ + return tf.train.cosine_decay_restarts( + learning_rate=learning_rate, + global_step=global_step, + first_decay_steps=params.first_decay_steps, + t_mul=params.t_mul, + m_mul=params.m_mul, + alpha=params.alpha + ) + return cosine_restart_decay_fn + + raise ValueError("Unsupported params.learning_rate_decay: %s" % params.learning_rate_decay) diff --git a/twml/twml/lookup/__init__.py b/twml/twml/lookup/__init__.py new file mode 100644 index 000000000..87392d719 --- /dev/null +++ b/twml/twml/lookup/__init__.py @@ -0,0 +1,11 @@ +from tensorflow.python.ops.lookup_ops import ( + index_table_from_file, + index_table_from_tensor, + index_to_string_table_from_file +) # noqa: F401 + + +""" +NOTE: Using `from tensorflow.python.ops.lookup_ops import index_table_from_tensor` in the code works. +This stub exists because it was easier to refactor code because twml is widely used. +""" diff --git a/twml/twml/metrics.py b/twml/twml/metrics.py new file mode 100644 index 000000000..ee2f82b74 --- /dev/null +++ b/twml/twml/metrics.py @@ -0,0 +1,1380 @@ +""" +This module contains custom tensorflow metrics used at Twitter. +Its components conform to conventions used by the ``tf.metrics`` module. 
+ +""" + +from collections import OrderedDict +from functools import partial + +import numpy as np +import tensorboard as tb +import tensorflow.compat.v1 as tf + + +CLAMP_EPSILON = 0.00001 + + +def total_weight_metric( + labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + with tf.variable_scope(name, 'total_weight', (labels, predictions, weights)): + total_weight = _metric_variable(name='total_weight', shape=[], dtype=tf.float64) + + if weights is None: + weights = tf.cast(tf.size(labels), total_weight.dtype, name="default_weight") + else: + weights = tf.cast(weights, total_weight.dtype) + + # add up the weights to get total weight of the eval set + update_total_weight = tf.assign_add(total_weight, tf.reduce_sum(weights), name="update_op") + + value_op = tf.identity(total_weight) + update_op = tf.identity(update_total_weight) + + if metrics_collections: + tf.add_to_collections(metrics_collections, value_op) + + if updates_collections: + tf.add_to_collections(updates_collections, update_op) + + return value_op, update_op + + +def num_samples_metric( + labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + with tf.variable_scope(name, 'num_samples', (labels, predictions, weights)): + num_samples = _metric_variable(name='num_samples', shape=[], dtype=tf.float64) + update_num_samples = tf.assign_add(num_samples, tf.cast(tf.size(labels), num_samples.dtype), name="update_op") + + value_op = tf.identity(num_samples) + update_op = tf.identity(update_num_samples) + + if metrics_collections: + tf.add_to_collections(metrics_collections, value_op) + + if updates_collections: + tf.add_to_collections(updates_collections, update_op) + + return value_op, update_op + + +def ctr(labels, predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + # pylint: disable=unused-argument + """ + Compute the weighted average positive sample ratio based on labels + (i.e. weighted average percentage of positive labels). + The name `ctr` (click-through-rate) is from legacy. + + Args: + labels: the ground truth value. + predictions: the predicted values, whose shape must match labels. Ignored for CTR computation. + weights: optional weights, whose shape must match labels . Weight is 1 if not set. + metrics_collections: optional list of collections to add this metric into. + updates_collections: optional list of collections to add the associated update_op into. + name: an optional variable_scope name. + + Return: + ctr: A `Tensor` representing positive sample ratio. + update_op: A update operation used to accumulate data into this metric. + """ + return tf.metrics.mean( + values=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) + + +def predicted_ctr(labels, predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + # pylint: disable=unused-argument + """ + Compute the weighted average positive ratio based on predictions, + (i.e. weighted averaged predicted positive probability). + The name `ctr` (click-through-rate) is from legacy. + + Args: + labels: the ground truth value. + predictions: the predicted values, whose shape must match labels. Ignored for CTR computation. + weights: optional weights, whose shape must match labels . Weight is 1 if not set. + metrics_collections: optional list of collections to add this metric into. 
+ updates_collections: optional list of collections to add the associated update_op into. + name: an optional variable_scope name. + + Return: + predicted_ctr: A `Tensor` representing the predicted positive ratio. + update_op: A update operation used to accumulate data into this metric. + """ + return tf.metrics.mean( + values=predictions, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) + + +def prediction_std_dev(labels, predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """ + Compute the weighted standard deviation of the predictions. + Note - this is not a confidence interval metric. + + Args: + labels: the ground truth value. + predictions: the predicted values, whose shape must match labels. Ignored for CTR computation. + weights: optional weights, whose shape must match labels . Weight is 1 if not set. + metrics_collections: optional list of collections to add this metric into. + updates_collections: optional list of collections to add the associated update_op into. + name: an optional variable_scope name. + + Return: + metric value: A `Tensor` representing the value of the metric on the data accumulated so far. + update_op: A update operation used to accumulate data into this metric. + """ + with tf.variable_scope(name, 'pred_std_dev', (labels, predictions, weights)): + labels = tf.cast(labels, tf.float64) + predictions = tf.cast(predictions, tf.float64) + + if weights is None: + weights = tf.ones(shape=tf.shape(labels), dtype=tf.float64, name="default_weight") + else: + weights = tf.cast(weights, tf.float64) + + # State kept during streaming of examples + total_weighted_preds = _metric_variable( + name='total_weighted_preds', shape=[], dtype=tf.float64) + total_weighted_preds_sq = _metric_variable( + name='total_weighted_preds_sq', shape=[], dtype=tf.float64) + total_weights = _metric_variable( + name='total_weights', shape=[], dtype=tf.float64) + + # Update state + update_total_weighted_preds = tf.assign_add(total_weighted_preds, tf.reduce_sum(weights * predictions)) + update_total_weighted_preds_sq = tf.assign_add(total_weighted_preds_sq, tf.reduce_sum(weights * predictions * predictions)) + update_total_weights = tf.assign_add(total_weights, tf.reduce_sum(weights)) + + # Compute output + def compute_output(tot_w, tot_wp, tot_wpp): + return tf.math.sqrt(tot_wpp / tot_w - (tot_wp / tot_w) ** 2) + std_dev_est = compute_output(total_weights, total_weighted_preds, total_weighted_preds_sq) + update_std_dev_est = compute_output(update_total_weights, update_total_weighted_preds, update_total_weighted_preds_sq) + + if metrics_collections: + tf.add_to_collections(metrics_collections, std_dev_est) + + if updates_collections: + tf.add_to_collections(updates_collections, update_std_dev_est) + + return std_dev_est, update_std_dev_est + + +def _get_arce_predictions(predictions, weights, label_weighted, labels, + up_weight, deprecated_rce, + total_positive, update_total_positive): + """ + Returns the ARCE predictions, total_positive, update_total_positive and weights + used by the rest of the twml.metrics.rce metric computation. 
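+
+ Returns:
+ A (predictions, total_positive, update_total_positive, weights) tuple.
+ predictions and weights may have been replaced relative to the inputs,
+ depending on the up_weight flag.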
+ """ + predictions_weighted = tf.multiply(predictions, weights, name="weighted_preds") + label_weighted_comp = tf.subtract(tf.reduce_sum(weights), tf.reduce_sum(label_weighted)) + pred_weight_comp = tf.subtract(tf.reduce_sum(weights), tf.reduce_sum(predictions_weighted)) + normalizer_comp = label_weighted_comp / pred_weight_comp + + if up_weight is False: + total_positive_unweighted = _metric_variable( + name='total_positive_unweighted', shape=[], dtype=tf.float32) + + update_total_positive_unweighted = tf.assign_add( + total_positive_unweighted, tf.reduce_sum(labels), + name="total_positive_unweighted_update") + + if deprecated_rce: + normalizer = tf.reduce_sum(labels) / tf.reduce_sum(label_weighted) + else: + # sum of labels / sum of weighted labels + normalizer = update_total_positive_unweighted / update_total_positive + + label_comp = tf.subtract(tf.to_float(tf.size(labels)), tf.reduce_sum(labels)) + normalizer_comp = label_comp / label_weighted_comp + + # note that up_weight=True changes these for the rest of the twml.metric.rce computation + weights = tf.ones(shape=tf.shape(labels), dtype=tf.float32, name="default_weight") + total_positive = total_positive_unweighted + update_total_positive = update_total_positive_unweighted + else: + if deprecated_rce: + normalizer = tf.reduce_sum(label_weighted) / tf.reduce_sum(predictions_weighted) + else: + # normalizer used for NRCE (and ARCE with up_weight=True) + total_prediction = _metric_variable(name='total_prediction', shape=[], dtype=tf.float32) + + # update the variable holding the sum of weighted predictions + update_total_prediction = tf.assign_add( + total_prediction, tf.reduce_sum(predictions_weighted), name="total_prediction_update") + + # this used to be tf.reduce_sum(label_weighted) / tf.reduce_sum(predictions_weighted) + # but it measure normalizer over batch was too flawed an approximation. + normalizer = update_total_positive / update_total_prediction + + pred_comp = tf.subtract(tf.ones(shape=tf.shape(labels), dtype=tf.float32), predictions) + pred_comp_norm = tf.multiply(pred_comp, normalizer_comp, name="normalized_predictions_comp") + pred_num = tf.multiply(predictions, normalizer, name="normalized_pred_numerator") + pred_denom = tf.add(pred_num, pred_comp_norm, name="normalized_pred_denominator") + predictions = pred_num / pred_denom + + return predictions, total_positive, update_total_positive, weights + + +def rce(labels, predictions, + weights=None, + normalize=False, + arce=False, + up_weight=True, + metrics_collections=None, + updates_collections=None, + name=None, + deprecated_rce=False): + """ + Compute the relative cross entropy (RCE). + The RCE is a relative measurement compared to the baseline model's performance. + The baseline model always predicts average click-through-rate (CTR). + The RCE measures, in percentage, how much better the predictions are, compared + to the baseline model, in terms of cross entropy loss. + + y = label; p = prediction; + binary cross entropy = y * log(p) + (1-y) * log(1-p) + + Args: + labels: + the ground true value. + predictions: + the predicted values, whose shape must match labels. + weights: + optional weights, whose shape must match labels . Weight is 1 if not set. + normalize: + if set to true, produce NRCEs used at Twitter. (normalize preds by weights first) + NOTE: if you don't understand what NRCE is, please don't use it. + arce: + if set to true, produces `ARCE `_. + This can only be activated if `normalize=True`. 
+ up_weight: + if set to true, produces arce in the up_weighted space (considers CTR after up_weighting + data), while False gives arce in the original space (only considers CTR before up_weighting). + In the actual version, this flag can only be activated if arce is True. + Notice that the actual version of NRCE corresponds to up_weight=True. + metrics_collections: + optional list of collections to add this metric into. + updates_collections: + optional list of collections to add the associated update_op into. + name: + an optional variable_scope name. + deprecated_rce: + enables the previous NRCE/ARCE calculations which calculated some label metrics + on the batch instead of on all batches seen so far. Note that the older metric + calculation is less stable, especially for smaller batch sizes. You should probably + never have to set this to True. + + Return: + rce_value: + A ``Tensor`` representing the RCE. + update_op: + A update operation used to accumulate data into this metric. + + .. note:: Must have at least 1 positive and 1 negative sample accumulated, + or RCE will come out as NaN. + """ + with tf.variable_scope(name, 'rce', (labels, predictions, weights)): + labels = tf.to_float(labels, name="label_to_float") + predictions = tf.to_float(predictions, name="predictions_to_float") + + if weights is None: + weights = tf.ones(shape=tf.shape(labels), dtype=tf.float32, name="default_weight") + else: + weights = tf.to_float(weights, name="weight_to_float") + + total_positive = _metric_variable(name='total_positive', shape=[], dtype=tf.float32) + total_loss = _metric_variable(name='total_loss', shape=[], dtype=tf.float32) + total_weight = _metric_variable(name='total_weight', shape=[], dtype=tf.float32) + + label_weighted = tf.multiply(labels, weights, name="weighted_label") + + update_total_positive = tf.assign_add( + total_positive, tf.reduce_sum(label_weighted), name="total_pos_update") + + if arce: + if normalize is False: + raise ValueError('This configuration of parameters is not actually allowed') + + predictions, total_positive, update_total_positive, weights = _get_arce_predictions( + predictions=predictions, weights=weights, deprecated_rce=deprecated_rce, + label_weighted=label_weighted, labels=labels, up_weight=up_weight, + total_positive=total_positive, update_total_positive=update_total_positive) + + elif normalize: + predictions_weighted = tf.multiply(predictions, weights, name="weighted_preds") + + if deprecated_rce: + normalizer = tf.reduce_sum(label_weighted) / tf.reduce_sum(predictions_weighted) + else: + total_prediction = _metric_variable(name='total_prediction', shape=[], dtype=tf.float32) + + # update the variable holding the sum of weighted predictions + update_total_prediction = tf.assign_add( + total_prediction, tf.reduce_sum(predictions_weighted), name="total_prediction_update") + + # this used to be tf.reduce_sum(label_weighted) / tf.reduce_sum(predictions_weighted) + # but it measure normalizer over batch was too flawed an approximation. 
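+ # i.e. normalizer = (all weighted positives seen so far) /
+ # (all weighted predictions seen so far), accumulated by the
+ # assign_add ops above rather than computed per batch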
+ normalizer = update_total_positive / update_total_prediction + + # NRCE + predictions = tf.multiply(predictions, normalizer, name="normalized_predictions") + + # clamp predictions to keep log(p) stable + clip_p = tf.clip_by_value(predictions, CLAMP_EPSILON, 1.0 - CLAMP_EPSILON, name="clip_p") + logloss = _binary_cross_entropy(pred=clip_p, target=labels, name="logloss") + + logloss_weighted = tf.multiply(logloss, weights, name="weighted_logloss") + + update_total_loss = tf.assign_add( + total_loss, tf.reduce_sum(logloss_weighted), name="total_loss_update") + update_total_weight = tf.assign_add( + total_weight, tf.reduce_sum(weights), name="total_weight_update") + + # metric value retrieval subgraph + ctr1 = tf.truediv(total_positive, total_weight, name="ctr") + # Note: we don't have to keep running averages for computing baseline CE. Because the prediction + # is constant for every sample, we can simplify it to the formula below. + baseline_ce = _binary_cross_entropy(pred=ctr1, target=ctr1, name="baseline_ce") + pred_ce = tf.truediv(total_loss, total_weight, name="pred_ce") + + rce_t = tf.multiply( + 1.0 - tf.truediv(pred_ce, baseline_ce), + 100, + name="rce") + + # metric update subgraph + ctr2 = tf.truediv(update_total_positive, update_total_weight, name="ctr_update") + # Note: we don't have to keep running averages for computing baseline CE. Because the prediction + # is constant for every sample, we can simplify it to the formula below. + baseline_ce2 = _binary_cross_entropy(pred=ctr2, target=ctr2, name="baseline_ce_update") + pred_ce2 = tf.truediv(update_total_loss, update_total_weight, name="pred_ce_update") + + update_op = tf.multiply( + 1.0 - tf.truediv(pred_ce2, baseline_ce2), + 100, + name="update_op") + + if metrics_collections: + tf.add_to_collections(metrics_collections, rce_t) + + if updates_collections: + tf.add_to_collections(updates_collections, update_op) + + return rce_t, update_op + + +def ce(p_true, p_est=None): + if p_est is None: + p_est = p_true + return _binary_cross_entropy(pred=p_est, target=p_true, name=None) + + +def rce_transform(outputs, labels, weights): + ''' + Construct an OrderedDict of quantities to aggregate over eval batches + outputs, labels, weights are TensorFlow tensors, and are assumed to + be of shape [N] for batch_size = N + Each entry in the output OrderedDict should also be of shape [N] + ''' + out_vals = OrderedDict() + out_vals['weighted_loss'] = weights * ce(p_true=labels, p_est=outputs) + out_vals['weighted_labels'] = labels * weights + out_vals['weight'] = weights + return out_vals + + +def rce_metric(aggregates): + ''' + input ``aggregates`` is an OrderedDict with the same keys as those created + by rce_transform(). The dict values are the aggregates (reduce_sum) + of the values produced by rce_transform(), and should be scalars. + output is the value of RCE + ''' + # cummulative weighted loss of model predictions + total_weighted_loss = aggregates['weighted_loss'] + total_weighted_labels = aggregates['weighted_labels'] + total_weight = aggregates['weight'] + + model_average_loss = total_weighted_loss / total_weight + baseline_average_loss = ce(total_weighted_labels / total_weight) + return 100.0 * (1 - model_average_loss / baseline_average_loss) + + +def metric_std_err(labels, predictions, + weights=None, + transform=rce_transform, metric=rce_metric, + metrics_collections=None, + updates_collections=None, + name='rce_std_err'): + """ + Compute the weighted standard error of the RCE metric on this eval set. 
+ This can be used for confidence intervals and unpaired hypothesis tests. + + Args: + labels: the ground truth value. + predictions: the predicted values, whose shape must match labels. + weights: optional weights, whose shape must match labels . Weight is 1 if not set. + transform: a function of the following form: + + .. code-block:: python + + def transform(outputs, labels, weights): + out_vals = OrderedDict() + ... + return out_vals + + where outputs, labels, and weights are all tensors of shape [eval_batch_size]. + The returned OrderedDict() should have values that are tensors of shape [eval_batch_size]. + These will be aggregated across many batches in the eval dataset, to produce + one scalar value per key of out_vals. + metric: a function of the following form + + .. code-block:: python + + def metric(aggregates): + ... + return metric_value + + where aggregates is an OrderedDict() having the same keys created by transform(). + Each of the corresponding dict values is the reduce_sum of the values produced by + transform(), and is a TF scalar. The return value should be a scalar representing + the value of the desired metric. + metrics_collections: optional list of collections to add this metric into. + updates_collections: optional list of collections to add the associated update_op into. + name: an optional variable_scope name. + + Return: + metric value: A `Tensor` representing the value of the metric on the data accumulated so far. + update_op: A update operation used to accumulate data into this metric. + """ + with tf.variable_scope(name, 'metric_std_err', (labels, predictions, weights)): + labels = tf.cast(labels, tf.float64) + predictions = tf.cast(predictions, tf.float64) + + if weights is None: + weights = tf.ones_like(labels, dtype=tf.float64, name="default_weight") + else: + weights = tf.cast(weights, tf.float64) + + labels = tf.reshape(labels, [-1]) + predictions = tf.reshape(predictions, [-1]) + predictions = tf.clip_by_value(predictions, CLAMP_EPSILON, 1.0 - CLAMP_EPSILON, name="clip_p") + weights = tf.reshape(weights, [-1]) + + # first apply the supplied transform function to the output, label, weight data + # returns an OrderedDict of 1xN tensors for N input samples + # for each sample, compute f = transform(pred, l, w) + transformed = transform(predictions, labels, weights) + + # we track 3 types of aggregate information + # 1. total number of samples + # 2. aggregated transformed samples (moment1), i.e. sum(f) + # 3. aggregated crosses of transformed samples (moment2), i.e. sum(f*f^T) + + # count total number of samples + sample_count = _metric_variable( + name='sample_count', shape=[], dtype=tf.int64) + update_sample_count = tf.assign_add(sample_count, tf.size(labels, out_type=sample_count.dtype)) + + # compose the ordered dict into a single vector + # so f can be treated as a single column vector rather than a collection of scalars + N = len(transformed) + transformed_vec = tf.stack(list(transformed.values()), axis=1) + + # compute and update transformed samples (1st order statistics) + # i.e. accumulate f into F as F += sum(f) + aggregates_1 = _metric_variable( + name='aggregates_1', shape=[N], dtype=tf.float64) + update_aggregates_1 = tf.assign_add(aggregates_1, tf.reduce_sum(transformed_vec, axis=0)) + + # compute and update crossed transformed samples (2nd order statistics) + # i.e. 
accumulate f*f^T into F2 as F2 += sum(f*transpose(f)) + aggregates_2 = _metric_variable( + name='aggregates_2', shape=[N, N], dtype=tf.float64) + moment_2_temp = ( + tf.reshape(transformed_vec, shape=[-1, N, 1]) + * tf.reshape(transformed_vec, shape=[-1, 1, N]) + ) + update_aggregates_2 = tf.assign_add(aggregates_2, tf.reduce_sum(moment_2_temp, axis=0)) + + def compute_output(agg_1, agg_2, samp_cnt): + # decompose the aggregates back into a dict to pass to the user-supplied metric fn + aggregates_dict = OrderedDict() + for i, key in enumerate(transformed.keys()): + aggregates_dict[key] = agg_1[i] + + metric_value = metric(aggregates_dict) + + # derivative of metric with respect to the 1st order aggregates + # i.e. d M(agg1) / d agg1 + metric_prime = tf.gradients(metric_value, agg_1, stop_gradients=agg_1) + + # estimated covariance of agg_1 + # cov(F) = sum(f*f^T) - (sum(f) * sum(f)^T) / N + # = agg_2 - (agg_1 * agg_1^T) / N + N_covariance_estimate = agg_2 - ( + tf.reshape(agg_1, shape=[-1, 1]) + @ tf.reshape(agg_1, shape=[1, -1]) + / tf.cast(samp_cnt, dtype=tf.float64) + ) + + # push N_covariance_estimate through a linearization of metric around agg_1 + # metric var = transpose(d M(agg1) / d agg1) * cov(F) * (d M(agg1) / d agg1) + metric_variance = ( + tf.reshape(metric_prime, shape=[1, -1]) + @ N_covariance_estimate + @ tf.reshape(metric_prime, shape=[-1, 1]) + ) + # result should be a single element, but the matmul is 2D + metric_variance = metric_variance[0][0] + metric_stderr = tf.sqrt(metric_variance) + return metric_stderr + + metric_stderr = compute_output(aggregates_1, aggregates_2, sample_count) + update_metric_stderr = compute_output(update_aggregates_1, update_aggregates_2, update_sample_count) + + if metrics_collections: + tf.add_to_collections(metrics_collections, metric_stderr) + + if updates_collections: + tf.add_to_collections(updates_collections, update_metric_stderr) + + return metric_stderr, update_metric_stderr + + +def lolly_nrce(labels, predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """ + Compute the Lolly NRCE. + + Note: As this NRCE calculation uses Taylor expansion, it becomes inaccurate when the ctr is large, + especially when the adjusted ctr goes above 1.0. + + Calculation: + + :: + + NRCE: lolly NRCE + BCE: baseline cross entropy + NCE: normalized cross entropy + CE: cross entropy + y_i: label of example i + p_i: prediction of example i + y: ctr + p: average prediction + a: normalizer + + Assumes any p_i and a * p_i is within [0, 1) + NRCE = (1 - NCE / BCE) * 100 + BCE = - sum_i(y_i * log(y) + (1 - y_i) * log(1 - y)) + = - (y * log(y) + (1 - y) * log(1 - y)) + a = y / p + CE = - sum_i(y_i * log(p_i) + (1 - y_i) * log(1 - p_i)) + NCE = - sum_i(y_i * log(a * p_i) + (1 - y_i) * log(1 - a * p_i)) + = - sum_i(y_i * log(p_i) + (1 - y_i) * log(1 - p_i)) + - sum_i(y_i * log(a)) + + sum_i((1 - y_i) * log(1 - p_i)) + - sum_i((1 - y_i) * log(1 - a * p_i)) + ~= CE - sum_i(y_i) * log(a) + + sum_i((1 - y_i) * (- sum_{j=1~5}(p_i^j / j))) + - sum_i((1 - y_i) * (- sum_{j=1~5}(a^j * p_i^j / j))) + # Takes 5 items from the Taylor expansion, can be increased if needed + # Error for each example is O(p_i^6) + = CE - sum_i(y_i) * log(a) + - sum_{j=1~5}(sum_i((1 - y_i) * p_i^j) / j) + + sum_{j=1~5}(sum_i((1 - y_i) * p_i^j) * a^j / j) + = CE - sum_i(y_i) * log(a) + + sum_{j=1~5}(sum_i((1 - y_i) * p_i^j) * (a^j - 1) / j) + + Thus we keep track of CE, sum_i(y_i), sum_i((1 - y_i) * p_i^j) for j=1~5. 
+    We also keep track of p and y by sum_i(y_i), sum_i(p_i), sum_i(1) so that
+    we can get a at the end, which leads to this NRCE.
+
+  NRCE uses ctr and average pctr to normalize the pctrs.
+  It removes the impact of prediction error from RCE.
+  Usually NRCE is higher, as the impact of prediction error on RCE is negative.
+  Removing prediction error in our model can make RCE closer to NRCE and thus improve RCE.
+
+  In Lolly NRCE we use ctr and average pctr of the whole dataset.
+  We thus remove the dataset level error in NRCE calculation.
+  In this case, when we want to improve RCE to the level of NRCE,
+  it is achievable, as dataset level prediction error is easy to remove by calibration.
+  Lolly NRCE is thus a good estimate of the potential gain from adding calibration.
+
+  In DBv2 NRCE, we use per-batch ctr and average pctr. We remove the batch level error.
+  This error is difficult to remove by modeling improvement,
+  at least not by simple calibration.
+  It thus cannot indicate the same opportunity as the Lolly NRCE does.
+
+  Args:
+    labels:
+      the ground truth value.
+    predictions:
+      the predicted values, whose shape must match labels.
+    weights:
+      optional weights, whose shape must match labels. Weight is 1 if not set.
+    metrics_collections:
+      optional list of collections to add this metric into.
+    updates_collections:
+      optional list of collections to add the associated update_op into.
+    name:
+      an optional variable_scope name.
+
+  Returns:
+    nrce_value:
+      A ``Tensor`` representing the Lolly NRCE.
+    update_op:
+      An update operation used to accumulate data into this metric.
+
+  Note: Must have at least 1 positive and 1 negative sample accumulated,
+    or NRCE will come out as NaN.
+  """
+  with tf.variable_scope(name, "lolly_nrce", (labels, predictions, weights)):
+    labels = tf.to_float(labels, name="label_to_float")
+    predictions = tf.to_float(predictions, name="predictions_to_float")
+
+    if weights is None:
+      weights = tf.ones(shape=tf.shape(labels), dtype=tf.float32, name="default_weight")
+    else:
+      weights = tf.to_float(weights, name="weight_to_float")
+
+    positive_weights = tf.multiply(labels, weights, name="positive_weights")
+
+    # clamp predictions to keep log(p) stable
+    clip_predictions = tf.clip_by_value(
+      predictions,
+      CLAMP_EPSILON,
+      1.0 - CLAMP_EPSILON,
+      name="clip_predictions")
+    weighted_predictions = tf.multiply(
+      predictions, weights,
+      name="weighted_predictions")
+
+    logloss = _binary_cross_entropy(pred=clip_predictions, target=labels, name="logloss")
+    weighted_logloss = tf.multiply(logloss, weights, name="weighted_logloss")
+
+    negatives = tf.subtract(
+      tf.ones(shape=tf.shape(labels), dtype=tf.float32),
+      labels,
+      name="negatives")
+    negative_predictions = tf.multiply(
+      predictions,
+      negatives,
+      name="negative_predictions")
+    weighted_negative_predictions = tf.multiply(
+      negative_predictions, weights,
+      name="weighted_negative_predictions")
+    negative_squared_predictions = tf.multiply(
+      negative_predictions,
+      negative_predictions,
+      name="negative_squared_predictions")
+    weighted_negative_squared_predictions = tf.multiply(
+      negative_squared_predictions, weights,
+      name="weighted_negative_squared_predictions")
+    negative_cubed_predictions = tf.multiply(
+      negative_squared_predictions,
+      negative_predictions,
+      name="negative_cubed_predictions")
+    weighted_negative_cubed_predictions = tf.multiply(
+      negative_cubed_predictions, weights,
+      name="weighted_negative_cubed_predictions")
+    negative_quartic_predictions = tf.multiply(
+      negative_cubed_predictions,
+
negative_predictions, + name="negative_quartic_predictions") + weighted_negative_quartic_predictions = tf.multiply( + negative_quartic_predictions, weights, + name="weighted_negative_quartic_predictions") + negative_quintic_predictions = tf.multiply( + negative_quartic_predictions, + negative_predictions, + name="negative_quintic_predictions") + weighted_negative_quintic_predictions = tf.multiply( + negative_quintic_predictions, weights, + name="weighted_negative_quintic_predictions") + + # Tracked stats + total_positive = _metric_variable(name="total_positive", shape=[], dtype=tf.float32) + total_weight = _metric_variable(name="total_weight", shape=[], dtype=tf.float32) + + total_prediction = _metric_variable(name="total_prediction", shape=[], dtype=tf.float32) + + total_negative_prediction = _metric_variable( + name="total_negative_prediction", + shape=[], dtype=tf.float32) + total_negative_squared_prediction = _metric_variable( + name="total_negative_squared_prediction", + shape=[], dtype=tf.float32) + total_negative_cubed_prediction = _metric_variable( + name="total_negative_cubed_prediction", + shape=[], dtype=tf.float32) + total_negative_quartic_prediction = _metric_variable( + name="total_negative_quartic_prediction", + shape=[], dtype=tf.float32) + total_negative_quintic_prediction = _metric_variable( + name="total_negative_quintic_prediction", + shape=[], dtype=tf.float32) + + total_loss = _metric_variable(name="total_loss", shape=[], dtype=tf.float32) + + # Update tracked stats + update_total_positive = tf.assign_add( + total_positive, tf.reduce_sum(positive_weights), name="total_positive_update") + update_total_weight = tf.assign_add( + total_weight, tf.reduce_sum(weights), name="total_weight_update") + update_total_prediction = tf.assign_add( + total_prediction, tf.reduce_sum(weighted_predictions), name="total_prediction_update") + update_total_negative_prediction = tf.assign_add( + total_negative_prediction, + tf.reduce_sum(weighted_negative_predictions), name="total_negative_prediction_update") + update_total_negative_squared_prediction = tf.assign_add( + total_negative_squared_prediction, + tf.reduce_sum(weighted_negative_squared_predictions), + name="total_negative_squared_prediction_update") + update_total_negative_cubed_prediction = tf.assign_add( + total_negative_cubed_prediction, + tf.reduce_sum(weighted_negative_cubed_predictions), + name="total_negative_cubed_prediction_update") + update_total_negative_quartic_prediction = tf.assign_add( + total_negative_quartic_prediction, + tf.reduce_sum(weighted_negative_quartic_predictions), + name="total_negative_quartic_prediction_update") + update_total_negative_quintic_prediction = tf.assign_add( + total_negative_quintic_prediction, + tf.reduce_sum(weighted_negative_quintic_predictions), + name="total_negative_quintic_prediction_update") + update_total_loss = tf.assign_add( + total_loss, tf.reduce_sum(weighted_logloss), name="total_loss_update") + + # metric value retrieval subgraph + # ctr of this batch + positive_rate = tf.truediv(total_positive, total_weight, name="positive_rate") + # Note: we don't have to keep running averages for computing baseline CE. Because the prediction + # is constant for every sample, we can simplify it to the formula below. 
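+    # i.e. with y = positive_rate, the baseline model predicts y for every sample, so
+    #   BCE = - (y * log(y) + (1 - y) * log(1 - y))
+    # which is exactly _binary_cross_entropy(pred=y, target=y).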
+    baseline_loss = _binary_cross_entropy(
+      pred=positive_rate,
+      target=positive_rate,
+      name="baseline_loss")
+
+    # normalizing ratio for nrce
+    # calculated using total ctr and pctr so the last batch has the dataset ctr and pctr
+    normalizer = tf.truediv(total_positive, total_prediction, name="normalizer")
+    # Taylor expansion to calculate nl = - sum(y * log(p * a) + (1 - y) * log(1 - p * a))
+    # log(1 - a * p) = -sum_{i=1~+inf}(a^i * p^i / i)
+    # log(1 - p) = -sum_{i=1~+inf}(p^i / i)
+    normalized_loss = (
+      total_loss -
+      total_positive * tf.log(normalizer) +
+      total_negative_prediction * (normalizer - 1) +
+      total_negative_squared_prediction * (normalizer * normalizer - 1) / 2 +
+      total_negative_cubed_prediction *
+      (normalizer * normalizer * normalizer - 1) / 3 +
+      total_negative_quartic_prediction *
+      (normalizer * normalizer * normalizer * normalizer - 1) / 4 +
+      total_negative_quintic_prediction *
+      (normalizer * normalizer * normalizer * normalizer * normalizer - 1) / 5)
+
+    # average normalized loss
+    avg_loss = tf.truediv(normalized_loss, total_weight, name="avg_loss")
+
+    nrce_t = tf.multiply(
+      1.0 - tf.truediv(avg_loss, baseline_loss),
+      100,
+      name="lolly_nrce")
+
+    # metric update subgraph
+    update_positive_rate = tf.truediv(
+      update_total_positive,
+      update_total_weight,
+      name="update_positive_rate")
+    # Note: we don't have to keep running averages for computing baseline CE. Because the prediction
+    # is constant for every sample, we can simplify it to the formula below.
+    update_baseline_loss = _binary_cross_entropy(
+      pred=update_positive_rate,
+      target=update_positive_rate,
+      name="update_baseline_loss")
+
+    update_normalizer = tf.truediv(
+      update_total_positive,
+      update_total_prediction,
+      name="update_normalizer")
+    update_normalized_loss = (
+      update_total_loss -
+      update_total_positive * tf.log(update_normalizer) +
+      update_total_negative_prediction *
+      (update_normalizer - 1) +
+      update_total_negative_squared_prediction *
+      (update_normalizer * update_normalizer - 1) / 2 +
+      update_total_negative_cubed_prediction *
+      (update_normalizer * update_normalizer * update_normalizer - 1) / 3 +
+      update_total_negative_quartic_prediction *
+      (update_normalizer * update_normalizer * update_normalizer *
+       update_normalizer - 1) / 4 +
+      update_total_negative_quintic_prediction *
+      (update_normalizer * update_normalizer * update_normalizer *
+       update_normalizer * update_normalizer - 1) / 5)
+
+    update_avg_loss = tf.truediv(
+      update_normalized_loss,
+      update_total_weight,
+      name="update_avg_loss")
+
+    update_op = tf.multiply(
+      1.0 - tf.truediv(update_avg_loss, update_baseline_loss),
+      100,
+      name="update_op")
+
+    if metrics_collections:
+      tf.add_to_collections(metrics_collections, nrce_t)
+
+    if updates_collections:
+      tf.add_to_collections(updates_collections, update_op)
+
+    return nrce_t, update_op
+
+
+def _binary_cross_entropy(pred, target, name):
+  return - tf.add(
+    target * tf.log(pred),
+    (1.0 - target) * tf.log(1.0 - pred),
+    name=name)
+
+
+# Copied from metrics_impl.py with minor modifications.
+# https://github.com/tensorflow/tensorflow/blob/v1.5.0/tensorflow/python/ops/metrics_impl.py#L39 +def _metric_variable(shape, dtype, validate_shape=True, name=None): + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" + + return tf.Variable( + lambda: tf.zeros(shape, dtype), + trainable=False, + collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES], + validate_shape=validate_shape, + name=name) + +PERCENTILES = np.linspace(0, 1, 101, dtype=np.float32) + +# metric_name: (metric, requires thresholded output) +SUPPORTED_BINARY_CLASS_METRICS = { + # TWML metrics + 'total_weight': (total_weight_metric, False), + 'num_samples': (num_samples_metric, False), + 'rce': (rce, False), + 'rce_std_err': (partial(metric_std_err, transform=rce_transform, metric=rce_metric, name='rce_std_err'), False), + 'nrce': (partial(rce, normalize=True), False), + 'lolly_nrce': (lolly_nrce, False), + 'arce': (partial(rce, normalize=True, arce=True), False), + 'arce_original': (partial(rce, normalize=True, arce=True, up_weight=False), False), + # CTR measures positive sample ratio. This terminology is inherited from Ads. + 'ctr': (ctr, False), + # predicted CTR measures predicted positive ratio. + 'predicted_ctr': (predicted_ctr, False), + 'pred_std_dev': (prediction_std_dev, False), + # thresholded metrics + 'accuracy': (tf.metrics.accuracy, True), + 'precision': (tf.metrics.precision, True), + 'recall': (tf.metrics.recall, True), + + 'false_positives': (tf.metrics.false_positives, True), + 'false_negatives': (tf.metrics.false_negatives, True), + 'true_positives': (tf.metrics.true_positives, True), + 'true_negatives': (tf.metrics.true_negatives, True), + + 'precision_at_percentiles': (partial(tf.metrics.precision_at_thresholds, thresholds=PERCENTILES), False), + 'recall_at_percentiles': (partial(tf.metrics.recall_at_thresholds, thresholds=PERCENTILES), False), + 'false_positives_at_percentiles': (partial(tf.metrics.false_positives_at_thresholds, thresholds=PERCENTILES), False), + 'false_negatives_at_percentiles': (partial(tf.metrics.false_negatives_at_thresholds, thresholds=PERCENTILES), False), + 'true_positives_at_percentiles': (partial(tf.metrics.true_positives_at_thresholds, thresholds=PERCENTILES), False), + 'true_negatives_at_percentiles': (partial(tf.metrics.true_negatives_at_thresholds, thresholds=PERCENTILES), False), + + # tensorflow metrics + 'roc_auc': (partial(tf.metrics.auc, curve='ROC', + summation_method='careful_interpolation'), False), + 'pr_auc': (partial(tf.metrics.auc, curve='PR', + summation_method='careful_interpolation'), False), + + # tensorboard curves + 'pr_curve': (tb.summary.v1.pr_curve_streaming_op, False), + + # deprecated metrics + 'deprecated_nrce': (partial(rce, normalize=True, deprecated_rce=True), False), + 'deprecated_arce': (partial(rce, normalize=True, arce=True, deprecated_rce=True), False), + 'deprecated_arce_original': (partial(rce, normalize=True, arce=True, + up_weight=False, deprecated_rce=True), False) +} + +# default metrics provided by get_binary_class_metric_fn +DEFAULT_BINARY_CLASS_METRICS = ['total_weight', 'num_samples', 'rce', 'rce_std_err', + 'nrce', 'arce', 'ctr', 'predicted_ctr', 'pred_std_dev', + 'accuracy', 'precision', 'recall', 'roc_auc', 'pr_auc'] + + +def get_binary_class_metric_fn(metrics=None): + """ + Returns a function having signature: + + .. code-block:: python + + def get_eval_metric_ops(graph_output, labels, weights): + ... 
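+      # eval_metric_ops maps each metric name to a (value_op, update_op) pair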
+      return eval_metric_ops
+
+  where the returned eval_metric_ops is a dict of common evaluation metric
+  Ops for binary classification. See `tf.estimator.EstimatorSpec
+  `_
+  for a description of eval_metric_ops. The graph_output is the result
+  dict returned by build_graph. Labels and weights are tf.Tensors.
+
+  The following graph_output keys are recognized:
+    output:
+      the raw predictions between 0 and 1. Required.
+    threshold:
+      A value between 0 and 1 used to threshold the output into a hard_output.
+      Defaults to 0.5 when threshold and hard_output are missing.
+      Either threshold or hard_output can be provided, but not both.
+    hard_output:
+      A thresholded output. Either threshold or hard_output can be provided, but not both.
+
+  Args:
+    metrics (list of String):
+      a list of metrics of interest. E.g. ['ctr', 'accuracy', 'rce']
+      Each element in the list can be a string from the following supported metrics, or can be
+      a tuple with three items: metric name, metric function, bool for thresholded output.
+
+      These metrics are evaluated and reported to tensorboard *during the eval phases only*.
+      Supported metrics:
+
+        - ctr (same as positive sample ratio.)
+        - rce (cross entropy loss compared to the baseline model of always predicting ctr)
+        - nrce (normalized rce, do not use this one if you do not understand what it is)
+        - `arce `_ (a more recently proposed improvement over NRCE)
+        - arce_original
+        - lolly_nrce (NRCE as it is computed in Lolly, with Taylor expansion)
+        - pr_auc
+        - roc_auc
+        - accuracy (percentage of predictions that are correct)
+        - precision (true positives) / (true positives + false positives)
+        - recall (true positives) / (true positives + false negatives)
+        - pr_curve (precision-recall curve)
+        - deprecated_arce (ARCE as it was calculated before a stability fix)
+        - deprecated_nrce (NRCE as it was calculated before a stability fix)
+
+      Example of metrics list with mixture of string and tuple:
+        metrics = [
+          'rce','nrce',
+          'roc_auc',  # default roc_auc metric
+          (
+            'roc_auc_500',  # give this metric a name
+            partial(tf.metrics.auc, curve='ROC', summation_method='careful_interpolation', num_thresholds=500),  # the metric fn
+            False,  # whether the metric requires thresholded output
+          )]
+
+      NOTE: When predicting rare events, roc_auc can be underestimated. Increasing num_thresholds
+      can reduce the underestimation. See go/roc-auc-pitfall for more details.
+
+      NOTE: accuracy / precision / recall apply to binary classification problems only.
+      I.e. a prediction is only considered correct if it matches the label. E.g. if the label
+      is 1.0, and the prediction is 0.99, it does not get credit. If you want to use
+      precision / recall / accuracy metrics with soft predictions, you'll need to threshold
+      your predictions into hard 0/1 labels.
+
+      When metrics is None (the default), it defaults to DEFAULT_BINARY_CLASS_METRICS:
+      [total_weight, num_samples, rce, rce_std_err, nrce, arce, ctr, predicted_ctr,
+      pred_std_dev, accuracy, precision, recall, roc_auc, pr_auc].
+  """
+  # pylint: disable=dict-keys-not-iterating
+  if metrics is None:
+    # remove expensive metrics by default for faster eval
+    metrics = list(DEFAULT_BINARY_CLASS_METRICS)
+
+  def get_eval_metric_ops(graph_output, labels, weights):
+    """
+    graph_output:
+      dict that is returned by build_graph given input features.
+    labels:
+      target labels associated with the batch.
+    weights:
+      weights of the samples.
+ """ + + eval_metric_ops = OrderedDict() + + preds = graph_output['output'] + + threshold = graph_output['threshold'] if 'threshold' in graph_output else 0.5 + + hard_preds = graph_output.get('hard_output') + if hard_preds is None: + hard_preds = tf.greater_equal(preds, threshold) + + # add metrics to eval_metric_ops dict + for metric in metrics: + if isinstance(metric, tuple) and len(metric) == 3: + metric_name, metric_factory, requires_threshold = metric + metric_name = metric_name.lower() + elif isinstance(metric, str): + metric_name = metric.lower() # metric name are case insensitive. + metric_factory, requires_threshold = SUPPORTED_BINARY_CLASS_METRICS.get(metric_name) + else: + raise ValueError("Metric should be either string or tuple of length 3.") + + if metric_name in eval_metric_ops: + # avoid adding duplicate metrics. + continue + + if metric_factory: + value_op, update_op = metric_factory( + labels=labels, + predictions=(hard_preds if requires_threshold else preds), + weights=weights, name=metric_name) + eval_metric_ops[metric_name] = (value_op, update_op) + else: + raise ValueError('Cannot find the metric named ' + metric_name) + + return eval_metric_ops + + return get_eval_metric_ops + + +def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1): + """ + Returns a function having signature: + + .. code-block:: python + + def get_eval_metric_ops(graph_output, labels, weights): + ... + return eval_metric_ops + + where the returned eval_metric_ops is a dict of common evaluation metric + Ops for concatenated binary classifications. See `tf.estimator.EstimatorSpec + `_ + for a description of eval_metric_ops. The graph_output is a the result + dict returned by build_graph. Labels and weights are tf.Tensors. + + In multiple binary classification problems, the + ``predictions`` (that is, ``graph_output['output']``) + are expected to have shape ``batch_size x n_classes``, + where ``n_classes`` is the number of binary classification. + Binary classification at output[i] is expected to discriminate between ``classes[i]`` (1) + and NOT ``classes[i]`` (0). The labels should be of the same shape as ``graph_output`` + with binary values (0 or 1). The weights can be of size ``batch_size`` or + ``batch_size x n_classes``. The ``class_dim`` contain separate probabilities, + and need to have separate metrics. + + The following graph_output keys are recognized: + output: + the raw predictions between 0 and 1. Required. + threshold: + A value between 0 and 1 used to threshold the output into a hard_output. + Defaults to 0.5 when threshold and hard_output are missing. + Either threshold or hard_output can be provided, but not both. + hard_output: + A thresholded output. Either threshold or hard_output can be provided, but not both. + + Args: + metrics (list of Metrics): + a list of metrics of interest. E.g. ['ctr', 'accuracy', 'rce'] + Element in the list can be a string from following supported metrics, or can be a tuple + with three items: metric name, metric function, bool for thresholded output. + + These metrics are evaluated and reported to tensorboard *during the eval phases only*. + Supported metrics: + + - ctr (same as positive sample ratio.) 
+        - rce (cross entropy loss compared to the baseline model of always predicting ctr)
+        - nrce (normalized rce, do not use this one if you do not understand what it is)
+        - pr_auc
+        - roc_auc
+        - accuracy (percentage of predictions that are correct)
+        - precision (true positives) / (true positives + false positives)
+        - recall (true positives) / (true positives + false negatives)
+        - pr_curve (precision-recall curve)
+
+      Example of metrics list with mixture of string and tuple:
+        metrics = [
+          'rce','nrce',
+          'roc_auc',  # default roc_auc metric
+          (
+            'roc_auc_500',  # give this metric a name
+            partial(tf.metrics.auc, curve='ROC', summation_method='careful_interpolation', num_thresholds=500),  # the metric fn
+            False,  # whether the metric requires thresholded output
+          )]
+
+      NOTE: When predicting rare events, roc_auc can be underestimated. Increasing num_thresholds
+      can reduce the underestimation. See go/roc-auc-pitfall for more details.
+
+      NOTE: accuracy / precision / recall apply to binary classification problems only.
+      I.e. a prediction is only considered correct if it matches the label. E.g. if the label
+      is 1.0, and the prediction is 0.99, it does not get credit. If you want to use
+      precision / recall / accuracy metrics with soft predictions, you'll need to threshold
+      your predictions into hard 0/1 labels.
+
+      When metrics is None (the default), it defaults to DEFAULT_BINARY_CLASS_METRICS:
+      [total_weight, num_samples, rce, rce_std_err, nrce, arce, ctr, predicted_ctr,
+      pred_std_dev, accuracy, precision, recall, roc_auc, pr_auc].
+
+    classes (list of strings):
+      In case of multiple binary class models, the names for each class or label.
+      These are used to display metrics on tensorboard.
+      If these are not specified, the index in the class or label dimension is used, and you'll
+      get metrics on tensorboard named like: accuracy_0, accuracy_1, etc.
+
+    class_dim (number):
+      Dimension of the classes in predictions. Defaults to 1, that is, batch_size x n_classes.
+  """
+  # pylint: disable=invalid-name,dict-keys-not-iterating
+  if metrics is None:
+    # remove expensive metrics by default for faster eval
+    metrics = list(DEFAULT_BINARY_CLASS_METRICS)
+
+  def get_eval_metric_ops(graph_output, labels, weights):
+    """
+    graph_output:
+      dict that is returned by build_graph given input features.
+    labels:
+      target labels associated with the batch.
+    weights:
+      weights of the samples.
+    """
+
+    eval_metric_ops = OrderedDict()
+
+    preds = graph_output['output']
+
+    threshold = graph_output['threshold'] if 'threshold' in graph_output else 0.5
+
+    hard_preds = graph_output.get('hard_output')
+    if hard_preds is None:
+      hard_preds = tf.greater_equal(preds, threshold)
+
+    shape = labels.get_shape()
+    # basic sanity check: multi_metric dimension must exist
+    assert len(shape) > class_dim, "Dimension specified by class_dim does not exist."
+
+    num_labels = shape[class_dim]
+    # If we are doing multi-class / multi-label metrics, the number of classes / labels must
+    # be known at graph construction time. This dimension cannot have size None.
+    assert num_labels is not None, "The multi-metric dimension cannot be None."
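+
+    # classes (if given) must name every label; weights may either be shared across
+    # classes (shape [batch_size]) or be given per class (shape [batch_size, n_classes])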
+    assert classes is None or len(classes) == num_labels, (
+      "Number of classes must match the number of labels")
+
+    weights_shape = weights.get_shape() if weights is not None else None
+    if weights_shape is None:
+      num_weights = None
+    elif len(weights_shape) > 1:
+      num_weights = weights_shape[class_dim]
+    else:
+      num_weights = 1
+
+    for i in range(num_labels):
+
+      # add metrics to eval_metric_ops dict
+      for metric in metrics:
+        if isinstance(metric, tuple) and len(metric) == 3:
+          metric_name, metric_factory, requires_threshold = metric
+          metric_name = metric_name.lower()
+        elif isinstance(metric, str):
+          metric_name = metric.lower()  # metric names are case-insensitive.
+          metric_factory, requires_threshold = SUPPORTED_BINARY_CLASS_METRICS.get(metric_name)
+        else:
+          raise ValueError("Metric should be either string or tuple of length 3.")
+
+        class_metric_name = metric_name + "_" + (classes[i] if classes is not None else str(i))
+
+        if class_metric_name in eval_metric_ops:
+          # avoid adding duplicate metrics.
+          continue
+
+        class_labels = tf.gather(labels, indices=[i], axis=class_dim)
+        class_preds = tf.gather(preds, indices=[i], axis=class_dim)
+        class_hard_preds = tf.gather(hard_preds, indices=[i], axis=class_dim)
+
+        if num_weights is None:
+          class_weights = None
+        elif num_weights == num_labels:
+          class_weights = tf.gather(weights, indices=[i], axis=class_dim)
+        elif num_weights == 1:
+          class_weights = weights
+        else:
+          raise ValueError("num_weights (%d) and num_labels (%d) do not match"
+                           % (num_weights, num_labels))
+
+        if metric_factory:
+          value_op, update_op = metric_factory(
+            labels=class_labels,
+            predictions=(class_hard_preds if requires_threshold else class_preds),
+            weights=class_weights, name=class_metric_name)
+          eval_metric_ops[class_metric_name] = (value_op, update_op)
+        else:
+          raise ValueError('Cannot find the metric named ' + metric_name)
+
+    return eval_metric_ops
+
+  return get_eval_metric_ops
+
+
+def _get_uncalibrated_metric_fn(calibrated_metric_fn, keep_weight=True):
+  """
+  Returns a function having signature:
+
+  .. code-block:: python
+
+    def get_eval_metric_ops(graph_output, labels, weights):
+      ...
+      return eval_metric_ops
+
+  where the returned eval_metric_ops is a dict of common evaluation metric
+  Ops with uncalibrated output.
+
+  The following graph_output keys are recognized:
+    uncalibrated_output:
+      the uncalibrated raw predictions between 0 and 1. Required.
+    output:
+      the calibrated predictions between 0 and 1.
+    threshold:
+      A value between 0 and 1 used to threshold the output into a hard_output.
+      Defaults to 0.5 when threshold and hard_output are missing.
+      Either threshold or hard_output can be provided, but not both.
+    hard_output:
+      A thresholded output. Either threshold or hard_output can be provided, but not both.
+
+  Args:
+    calibrated_metric_fn: metric function with calibration and weights.
+    keep_weight: Bool indicating whether to keep the weights.
+  """
+  metric_scope = 'uncalibrated' if keep_weight else 'unweighted'
+
+  def get_eval_metric_ops(graph_output, labels, weights):
+    """
+    graph_output:
+      dict that is returned by build_graph given input features.
+    labels:
+      target labels associated with the batch.
+    weights:
+      weights of the samples.
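+      (when keep_weight is False, these weights are replaced with all-ones)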
+ """ + with tf.variable_scope(metric_scope): + if 'uncalibrated_output' not in graph_output: + raise Exception("Missing uncalibrated_output in graph_output!") + un_calibrated_weights = weights if keep_weight else tf.ones_like(weights) + uncalibrated_output = { + 'output': graph_output['uncalibrated_output'], + 'threshold': graph_output.get('threshold', 0.5), + 'hard_output': graph_output.get('hard_output'), + **{k: v for k, v in graph_output.items() if k not in ['output', 'threshold', 'hard_output']} + } + + eval_metrics_ops = calibrated_metric_fn(uncalibrated_output, labels, un_calibrated_weights) + + renamed_metrics_ops = {f'{metric_scope}_{k}': v for k, v in eval_metrics_ops.items()} + return renamed_metrics_ops + + return get_eval_metric_ops + + +def get_multi_binary_class_uncalibrated_metric_fn( + metrics, classes=None, class_dim=1, keep_weight=True): + """ + Returns a function having signature: + + .. code-block:: python + + def get_eval_metric_ops(graph_output, labels, weights): + ... + return eval_metric_ops + + where the returned eval_metric_ops is a dict of common evaluation metric + Ops for concatenated binary classifications without calibration. + + Note: 'uncalibrated_output' is required key in graph_output. + + The main use case for this function is: + + 1) To calculated roc-auc for rare event. + Calibrated prediction score for rare events will be concentrated near zero. As a result, + the roc-auc can be seriously underestimated with current implementation in tf.metric.auc. + Since roc-auc is invariant against calibration, we can directly use uncalibrated score for roc-auc. + For more details, please refer to: go/roc-auc-invariance. + + 2) To set keep_weight=False and get unweighted and uncalibrated metrics. + This is useful to eval how the model is fitted to its actual training data, since + often time the model is trained without weight. + + Args: + metrics (list of String): + a list of metrics of interest. E.g. ['ctr', 'accuracy', 'rce'] + Element in the list can be a string from supported metrics, or can be a tuple + with three items: metric name, metric function, bool for thresholded output. + These metrics are evaluated and reported to tensorboard *during the eval phases only*. + + When metrics is None (the default), it defaults to: + [rce, nrce, arce, ctr, predicted_ctr, accuracy, precision, recall, prauc, roc_auc], + + classes (list of strings): + In case of multiple binary class models, the names for each class or label. + These are used to display metrics on tensorboard. + If these are not specified, the index in the class or label dimension is used, and you'll + get metrics on tensorboard named like: accuracy_0, accuracy_1, etc. + + class_dim (number): + Dimension of the classes in predictions. Defaults to 1, that is, batch_size x n_classes. + + keep_weight (bool): + Whether to keep weights for the metric. + """ + + calibrated_metric_fn = get_multi_binary_class_metric_fn( + metrics, classes=classes, class_dim=class_dim) + return _get_uncalibrated_metric_fn(calibrated_metric_fn, keep_weight=keep_weight) + + +def combine_metric_fns(*fn_list): + """ + Combine multiple metric functions. + For example, we can combine metrics function generated by + get_multi_binary_class_metric_fn and get_multi_binary_class_uncalibrated_metric_fn. + + Args: + *fn_list: Multiple metric functions to be combined + + Returns: + Combined metric function. 
+ """ + def combined_metric_ops(*args, **kwargs): + eval_metric_ops = OrderedDict() + for fn in fn_list: + eval_metric_ops.update(fn(*args, **kwargs)) + return eval_metric_ops + return combined_metric_ops diff --git a/twml/twml/optimizers/__init__.py b/twml/twml/optimizers/__init__.py new file mode 100644 index 000000000..eaa29883c --- /dev/null +++ b/twml/twml/optimizers/__init__.py @@ -0,0 +1,4 @@ +from twitter.deepbird.compat.v1.optimizers import ( + LazyAdamOptimizer, + optimize_loss, + OPTIMIZER_SUMMARIES) # noqa: F401 diff --git a/twml/twml/parsers.py b/twml/twml/parsers.py new file mode 100644 index 000000000..eac60083a --- /dev/null +++ b/twml/twml/parsers.py @@ -0,0 +1,20 @@ +''' +Contains implementations of functions to parse training and evaluation data. + +Modelers can use the functions in this module as the the train/eval_parse_fn of +the DataRecordTrainer constructor to customize how to parse their datasets. + +Modelers may also provide custom implementations of train/eval_parse_fn using these as reference. +''' + +from twitter.deepbird.io.legacy.parsers import ( + convert_to_supervised_input_receiver_fn, # noqa: F401 + get_continuous_parse_fn, # noqa: F401 + get_default_parse_fn, # noqa: F401 + get_features_as_tensor_dict, # noqa: F401 + get_labels_in_features_parse_fn, # noqa: F401 + get_serving_input_receiver_fn_feature_dict, # noqa: F401 + get_sparse_parse_fn, # noqa: F401 + get_sparse_serving_input_receiver_fn, # noqa: F401 + get_tensor_parse_fn, # noqa: F401 +) diff --git a/twml/twml/readers/__init__.py b/twml/twml/readers/__init__.py new file mode 100644 index 000000000..06a6d79f5 --- /dev/null +++ b/twml/twml/readers/__init__.py @@ -0,0 +1,7 @@ +# pylint: disable=wildcard-import +""" This module contains data readers """ + +from .batch_prediction_request import BatchPredictionRequest # noqa: F401 +from .data_record import DataRecord, SPARSE_DATA_RECORD_FEATURE_FIELDS # noqa: F401 +from .hashed_batch_prediction_request import HashedBatchPredictionRequest # noqa: F401 +from .hashed_data_record import HashedDataRecord # noqa: F401 \ No newline at end of file diff --git a/twml/twml/readers/batch_prediction_request.py b/twml/twml/readers/batch_prediction_request.py new file mode 100644 index 000000000..512a8c514 --- /dev/null +++ b/twml/twml/readers/batch_prediction_request.py @@ -0,0 +1,8 @@ +# pylint: disable=invalid-name +""" +This module implements the reader for BatchPredictionRequest. +""" + +from twitter.deepbird.io.legacy.readers.batch_prediction_request import ( + BatchPredictionRequest # noqa: F401 +) diff --git a/twml/twml/readers/data_record.py b/twml/twml/readers/data_record.py new file mode 100644 index 000000000..d1c377afd --- /dev/null +++ b/twml/twml/readers/data_record.py @@ -0,0 +1,15 @@ +# pylint: disable=invalid-name +""" +This module includes facilities for manipulating data records. 
+""" + +from twitter.deepbird.io.legacy.readers.data_record import ( + _SPEC_TO_TF, # noqa: F401 + SPARSE_DATA_RECORD_FEATURE_FIELDS, # noqa: F401 + _FeaturesBase, # noqa: F401 + _Features, # noqa: F401 + _DiscreteFeatures, # noqa: F401 + _StringFeatures, # noqa: F401 + _BaseDataRecord, # noqa: F401 + DataRecord, # noqa: F401 +) diff --git a/twml/twml/readers/hashed_batch_prediction_request.py b/twml/twml/readers/hashed_batch_prediction_request.py new file mode 100644 index 000000000..5850c4497 --- /dev/null +++ b/twml/twml/readers/hashed_batch_prediction_request.py @@ -0,0 +1,8 @@ +# pylint: disable=invalid-name +""" +This module implements the reader for HashedBatchPredictionRequest. +""" + +from twitter.deepbird.io.legacy.readers.hashed_batch_prediction_request import ( + HashedBatchPredictionRequest # noqa: F401 +) diff --git a/twml/twml/readers/hashed_data_record.py b/twml/twml/readers/hashed_data_record.py new file mode 100644 index 000000000..1ff9ce816 --- /dev/null +++ b/twml/twml/readers/hashed_data_record.py @@ -0,0 +1,12 @@ +# checkstyle: noqa +# pylint: disable=invalid-name +""" +This module includes facilities for manipulating hashed data records. +""" + +from twitter.deepbird.io.legacy.readers.hashed_data_record import ( + _HASHED_FIELDS, + _FEATURE_NAMES, + _FEATURE_TYPES, + HashedDataRecord, +) diff --git a/twml/twml/saved_model_cli/__init__.py b/twml/twml/saved_model_cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/twml/twml/saved_model_cli/__main__.py b/twml/twml/saved_model_cli/__main__.py new file mode 100644 index 000000000..ad5326431 --- /dev/null +++ b/twml/twml/saved_model_cli/__main__.py @@ -0,0 +1,9 @@ +""" +This module is responsible for running saved_model_cli. +""" +import sys + +from tensorflow.python.tools import saved_model_cli + +if __name__ == '__main__': + sys.exit(saved_model_cli.main()) diff --git a/twml/twml/summary/__init__.py b/twml/twml/summary/__init__.py new file mode 100644 index 000000000..284d7cf3f --- /dev/null +++ b/twml/twml/summary/__init__.py @@ -0,0 +1,6 @@ +from tensorflow.python.ops.summary_ops_v2 import flush # noqa: F401 + +""" +NOTE: Using `from tensorflow.python.ops.summary_ops_v2 import flush` in the code works. +This stub exists because it was easier to refactor code because twml is widely used. +""" diff --git a/twml/twml/tensorboard/__init__.py b/twml/twml/tensorboard/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/twml/twml/tensorboard/__main__.py b/twml/twml/tensorboard/__main__.py new file mode 100644 index 000000000..c426060d1 --- /dev/null +++ b/twml/twml/tensorboard/__main__.py @@ -0,0 +1,16 @@ +""" +This module is responsible for running tensorboard. 
+""" +import logging +import re +import sys + +from tensorboard.main import run_main + + +if __name__ == '__main__': + # Tensorboard relies on werkzeug for its HTTP server which logs at info level + # by default + logging.getLogger('werkzeug').setLevel(logging.WARNING) + sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) + sys.exit(run_main()) diff --git a/twml/twml/tensorio.py b/twml/twml/tensorio.py new file mode 100644 index 000000000..bc551ac56 --- /dev/null +++ b/twml/twml/tensorio.py @@ -0,0 +1,161 @@ +# pylint: disable=missing-docstring, bare-except, pointless-statement, +# pointless-string-statement, redundant-unittest-assert, no-else-return, +# no-member, old-style-class, dangerous-default-value, protected-access, +# too-few-public-methods + +import os + +import numpy as np +import yaml + + +""" +Utility to load tensors serialized by Deepbird V1. + +Note that Deepbird V1 serialize tensor names as \"weight\".\'1\'. +For user-friendliness, the quotes are removed from the tensor names. +""" + + +# helper class used to assist hierarchical key access by remembering intermediate keys. +class _KeyRecorder(object): + def __init__(self, tensorio, keys=[]): + self.tensorio = tensorio + self.keys = keys + + def __getitem__(self, k): + new_keys = self.keys + [str(k)] + prefix = ".".join(new_keys) + + key_list = self.tensorio.list_tensors() + + # if we have a complete key, load the tensor. + if prefix in key_list: + return self.tensorio._load(prefix) + + # we don't have a complete key yet, but at least one tensor should start with this prefix. + for k_value in key_list: + if k_value.startswith(prefix): + return _KeyRecorder(self.tensorio, new_keys) + + # if no key starts with the prefix, this _key_recorder is not valid. + raise ValueError("Key not found: " + prefix) + + +# convert tensorio tensor type to numpy data type. +# also returns element size in bytes. +def _get_data_type(data_type): + if data_type == 'Double': + return (np.float64, 8) + + if data_type == 'Float': + return (np.float32, 4) + + if data_type == 'Int': + return (np.int32, 4) + + if data_type == 'Long': + return (np.int64, 8) + + if data_type == 'Byte': + return (np.int8, 1) + + raise ValueError('Unexpected tensorio data type: ' + data_type) + + +class TensorIO(object): + """ + Construct a TensorIO class. + tensorio_path: a directory containing tensors serialized using tensorio. tar file not supported. + mmap_tensor: + By default, loaded tensors use mmap storage. + Set this to false to not use mmap. Useful when loading multiple tensors. + """ + + def __init__(self, tensorio_path, mmap_tensor=True): + self._tensorio_path = tensorio_path + self._mmap_tensor = mmap_tensor + + # Make sure we can locate spec.yaml. + yaml_file = os.path.join(tensorio_path, 'spec.yaml') + if not os.path.exists(yaml_file): + raise ValueError('Invalid tensorio path: no spec.yaml found.') + + # load spec.yaml. + with open(yaml_file, 'r') as file_open: + # Note that tensor names in the yaml are like this: \"weight\".\'1\' + # For user-friendliness, we remove the quotes. + _spec = yaml.safe_load(file_open) + self._spec = {k.replace("'", '').replace('"', ''): v for (k, v) in _spec.items()} + + def list_tensors(self): + """ + Returns a list of tensors saved in the given path. + """ + return self._spec.keys() + + def _load_tensor(self, name): + """ + Load Tensor with the given name. + Raise value error if the named tensor is not found. + Returns a numpy array if the named tensor is found. 
+ """ + tensor_info = self._spec[name] + if tensor_info['type'] != 'tensor': + raise ValueError('Trying to load a tensor of unknown type: ' + tensor_info['type']) + + filename = os.path.join(self._tensorio_path, tensor_info['filename']) + (data_type, element_size) = _get_data_type(tensor_info['tensorType']) + + np_array = np.memmap( + filename, + dtype=data_type, + mode='r', + # -1 because lua offset is 1 based. + offset=(tensor_info['offset'] - 1) * element_size, + shape=tuple(tensor_info['size']), + order='C', + ) + + return np_array if self._mmap_tensor else np_array[:].copy() + + def _load_nontensor_data(self, name): + """ + Load non-tensor data with the given name. + Returns a python string. + """ + tensor_info = self._spec[name] + return tensor_info['data'] + + def _load(self, name): + """ + Load data serialized under the given name, it could be a tensor or regular data. + """ + if name not in self._spec: + raise ValueError('The specified key {} is not found in {}'.format(name, self._tensorio_path)) + + data_type = self._spec[name]['type'] + if data_type == 'tensor': + return self._load_tensor(name) + else: + return self._load_nontensor_data(name) + + def load_all(self): + """ + Load all tensors stored in the tensorio directory. + Returns a dictionary from tensor name to numpy arrays. + """ + return {k: self._load(k) for k in self._spec} + + ########################################### + # The below are utilities for convenience # + ########################################### + def __getitem__(self, k): + """ + Shorthand for _load_tensor, but also supports hierarchical access like: tensorio['a']['b']['1'] + """ + if k in self._spec: + # We have a full tensor name, directly load it. + return self._load_tensor(k) + else: + return _KeyRecorder(self)[k] diff --git a/twml/twml/tracking/__init__.py b/twml/twml/tracking/__init__.py new file mode 100644 index 000000000..008a59f70 --- /dev/null +++ b/twml/twml/tracking/__init__.py @@ -0,0 +1,5 @@ +""" +This module contains the ExperimentTracker class. +""" + +from .experiment_tracker import ExperimentTracker # noqa: F401 diff --git a/twml/twml/tracking/experiment_tracker.py b/twml/twml/tracking/experiment_tracker.py new file mode 100644 index 000000000..4f275ba4b --- /dev/null +++ b/twml/twml/tracking/experiment_tracker.py @@ -0,0 +1,543 @@ +""" +This module contains the experiment tracker for tracking training in ML Metastore +""" +from contextlib import contextmanager +from datetime import datetime +import getpass +import hashlib +import os +import re +import sys +import time + +from absl import logging +import tensorflow.compat.v1 as tf +from twml.hooks import MetricsUpdateHook + + +try: + from urllib import quote as encode_url +except ImportError: + from urllib.parse import quote as encode_url + + +try: + # ML Metastore packages might not be available on GCP. + # If they are not found, tracking is disabled + import requests + from com.twitter.mlmetastore.modelrepo.client import ModelRepoClient + from com.twitter.mlmetastore.modelrepo.core.path import ( + check_valid_id, get_components_from_id, generate_id) + from com.twitter.mlmetastore.modelrepo.core import ( + DeepbirdRun, Experiment, FeatureConfig, FeatureConfigFeature, Model, ProgressReport, Project, StatusUpdate) +except ImportError: + ModelRepoClient = None + + +class ExperimentTracker(object): + """ + A tracker that records twml runs in ML Metastore. + """ + + def __init__(self, params, run_config, save_dir): + """ + + Args: + params (python dict): + The trainer params. 
ExperimentTracker uses `params.experiment_tracking_path` (String) and + `params.disable_experiment_tracking`. + If `experiment_tracking_path` is set to None, the tracker tries to guess a path with + save_dir. + If `disable_experiment_tracking` is True, the tracker is disabled. + run_config (tf.estimator.RunConfig): + The run config used by the estimator. + save_dir (str): + save_dir of the trainer + """ + if isinstance(params, dict): + self._params = params + else: + # preserving backward compatibility for people still using HParams + logging.warning("Please stop using HParams and use python dicts. HParams are removed in TF 2") + self._params = dict((k, v) for k, v in params.values().items() if v != 'null') + self._run_config = run_config + self._graceful_shutdown_port = self._params.get('health_port') + + self.tracking_path = self._params.get('experiment_tracking_path') + is_tracking_path_too_long = self.tracking_path is not None and len(self.tracking_path) > 256 + + if is_tracking_path_too_long: + raise ValueError("Experiment Tracking Path longer than 256 characters") + + self.disabled = ( + self._params.get('disable_experiment_tracking', False) or + not self._is_env_eligible_for_tracking() or + ModelRepoClient is None + ) + + self._is_hogwild = bool(os.environ.get('TWML_HOGWILD_PORTS')) + + self._is_distributed = bool(os.environ.get('TF_CONFIG')) + + self._client = None if self.disabled else ModelRepoClient() + + run_name_from_environ = self.run_name_from_environ() + run_name_can_be_inferred = ( + self.tracking_path is not None or run_name_from_environ is not None) + + # Turn the flags off as needed in hogwild / distributed + if self._is_hogwild or self._is_distributed: + self._env_eligible_for_recording_experiment = ( + self._run_config.task_type == "evaluator") + if run_name_can_be_inferred: + self._env_eligible_for_recording_export_metadata = ( + self._run_config.task_type == "chief") + else: + logging.info( + 'experiment_tracking_path is not set and can not be inferred. ' + 'Recording export metadata is disabled because the chief node and eval node ' + 'are setting different experiment tracking paths.') + self._env_eligible_for_recording_export_metadata = False + else: + # Defaults to True + self._env_eligible_for_recording_experiment = True + self._env_eligible_for_recording_export_metadata = True + + if not self.disabled: + # Sanitize passed in experiment tracking paths. e.g. own:proJ:exp:Run.Name + # -> own:proj:exp:Run_Name + if self.tracking_path: + try: + check_valid_id(self.tracking_path) + except ValueError as err: + logging.error(f'Invalid experiment tracking path provided. Sanitizing: {self.tracking_path}\nError: {err}') + self.tracking_path = generate_id( + owner=self.path['owner'], + project_name=self.path['project_name'], + experiment_name=self.path['experiment_name'], + run_name=self.path['run_name'] + ) + logging.error(f'Generated sanitized experiment tracking path: {self.tracking_path}') + else: + logging.info( + 'No experiment_tracking_path set. Experiment Tracker will try to guess a path') + self.tracking_path = self.guess_path(save_dir, run_name_from_environ) + logging.info('Guessed path: %s', self.tracking_path) + + # additional check to see if generated path is valid + try: + check_valid_id(self.tracking_path) + except ValueError as err: + logging.error( + 'Could not generate valid experiment tracking path. Disabling tracking. 
' +
+          'Error:\n{}'.format(err)
+        )
+        self.disabled = True
+
+    self.project_id = None if self.disabled else '{}:{}'.format(
+      self.path['owner'], self.path['project_name'])
+    self.base_run_id = None if self.disabled else self.tracking_path
+    self._current_run_name_suffix = None
+
+    self._current_tracker_hook = None
+
+    if self.disabled:
+      logging.info('Experiment Tracker is disabled')
+    else:
+      logging.info('Experiment Tracker initialized with base run id: %s', self.base_run_id)
+
+  @contextmanager
+  def track_experiment(self, eval_hooks, get_estimator_spec_fn, name=None):
+    """
+    A context manager for tracking an experiment. It should wrap the training loop.
+    An experiment tracker eval hook is appended to eval_hooks to collect metrics.
+
+    Args:
+      eval_hooks (list):
+        The list of eval_hooks to be used. When it's not None and does not contain any
+        MetricsUpdateHook, an experiment tracker eval hook is appended to it. When it contains
+        any MetricsUpdateHook, this tracker is disabled to avoid conflict with the legacy
+        Model Repo tracker (`TrackRun`).
+      get_estimator_spec_fn (func):
+        A function to get the current EstimatorSpec of the trainer, used by the eval hook.
+      name (str):
+        Name of this training or evaluation. Used as a suffix of the run_id.
+
+    Returns:
+      The tracker's eval hook which is appended to eval_hooks.
+    """
+
+    # disable this tracker if legacy TrackRun hook is present
+    # TODO: remove this once we completely deprecate the old TrackRun interface
+    if eval_hooks is not None:
+      self.disabled = self.disabled or any(isinstance(x, MetricsUpdateHook) for x in eval_hooks)
+
+    logging.info('Is environment eligible for recording experiment: %s',
+                 self._env_eligible_for_recording_experiment)
+
+    if self._env_eligible_for_recording_experiment and self._graceful_shutdown_port:
+      requests.post('http://localhost:{}/track_training_start'.format(
+        self._graceful_shutdown_port
+      ))
+
+    if self.disabled or eval_hooks is None:
+      yield None
+    else:
+      assert self._current_tracker_hook is None, 'experiment tracking has been started already'
+
+      if name is not None:
+        self._current_run_name_suffix = '_' + name
+
+      logging.info('Starting experiment tracking. Path: %s', self._current_run_id)
+      logging.info('Is environment eligible for recording export metadata: %s',
+                   self._env_eligible_for_recording_export_metadata)
+      logging.info('This run will be available at: http://go/mldash/experiments/%s',
+                   encode_url(self.experiment_id))
+
+      try:
+        self._record_run()
+        self._add_run_status(StatusUpdate(self._current_run_id, status='RUNNING'))
+        self._register_for_graceful_shutdown()
+
+        self._current_tracker_hook = self.create_eval_hook(get_estimator_spec_fn)
+      except Exception as err:
+        logging.error(
+          'Failed to record run. This experiment will not be tracked. Error: %s', str(err))
+        self._current_tracker_hook = None
+
+      if self._current_tracker_hook is None:
+        yield None
+      else:
+        try:
+          eval_hooks.append(self._current_tracker_hook)
+          yield self._current_tracker_hook
+        except Exception as err:
+          self._add_run_status(
+            StatusUpdate(self._current_run_id, status='FAILED', description=str(err)))
+          self._deregister_for_graceful_shutdown()
+          self._current_tracker_hook = None
+          self._current_run_name_suffix = None
+          logging.error('Experiment tracking done.
Experiment failed.') + raise + + try: + if self._current_tracker_hook.metric_values: + self._record_update(self._current_tracker_hook.metric_values) + self._add_run_status(StatusUpdate(self._current_run_id, status='SUCCESS')) + logging.info('Experiment tracking done. Experiment succeeded.') + except Exception as err: + logging.error( + 'Failed to update mark run as successful. Error: %s', str(err)) + finally: + self._deregister_for_graceful_shutdown() + self._current_tracker_hook = None + self._current_run_name_suffix = None + + def create_eval_hook(self, get_estimator_spec_fn): + """ + Create an eval_hook to track eval metrics + + Args: + get_estimator_spec_fn (func): + A function that returns the current EstimatorSpec of the trainer. + """ + return MetricsUpdateHook( + get_estimator_spec_fn=get_estimator_spec_fn, + add_metrics_fn=self._record_update) + + def register_model(self, export_path): + """ + Record the exported model. + + Args: + export_path (str): + The path to the exported model. + """ + if self.disabled: + return None + + try: + logging.info('Model is exported to %s. Computing hash of the model.', export_path) + model_hash = self.compute_model_hash(export_path) + logging.info('Model hash: %s. Registering it in ML Metastore.', model_hash) + self._client.register_model(Model(model_hash, self.path['owner'], self.base_run_id)) + except Exception as err: + logging.error('Failed to register model. Error: %s', str(err)) + + def export_feature_spec(self, feature_spec_dict): + """ + Export feature spec to ML Metastore (go/ml-metastore). + + Please note that the feature list in FeatureConfig only keeps the list of feature hash ids due + to the 1mb upper limit for values in manhattan, and more specific information (feature type, + feature name) for each feature config feature is stored separately in FeatureConfigFeature dataset. + + Args: + feature_spec_dict (dict): A dictionary obtained from FeatureConfig.get_feature_spec() + """ + if self.disabled or not self._env_eligible_for_recording_export_metadata: + return None + + try: + logging.info('Exporting feature spec to ML Metastore.') + feature_list = feature_spec_dict['features'] + label_list = feature_spec_dict['labels'] + weight_list = feature_spec_dict['weight'] + self._client.add_feature_config(FeatureConfig(self._current_run_id, list(feature_list.keys()), + list(label_list.keys()), list(weight_list.keys()))) + + feature_config_features = [ + FeatureConfigFeature( + hash_id=_feature_hash_id, + feature_name=_feature['featureName'], + feature_type=_feature['featureType'] + ) + for _feature_hash_id, _feature in zip(feature_list.keys(), feature_list.values()) + ] + self._client.add_feature_config_features(list(feature_list.keys()), feature_config_features) + + feature_config_labels = [ + FeatureConfigFeature( + hash_id=_label_hash_id, + feature_name=_label['featureName'] + ) + for _label_hash_id, _label in zip(label_list.keys(), label_list.values()) + ] + self._client.add_feature_config_features(list(label_list.keys()), feature_config_labels) + + feature_config_weights = [ + FeatureConfigFeature( + hash_id=_weight_hash_id, + feature_name=_weight['featureName'], + feature_type=_weight['featureType'] + ) + for _weight_hash_id, _weight in zip(weight_list.keys(), weight_list.values()) + ] + self._client.add_feature_config_features(list(weight_list.keys()), feature_config_weights) + + except Exception as err: + logging.error('Failed to export feature spec. 
Error: %s', str(err)) + + @property + def path(self): + if self.disabled: + return None + return get_components_from_id(self.tracking_path, ensure_valid_id=False) + + @property + def experiment_id(self): + if self.disabled: + return None + return '%s:%s:%s' % (self.path['owner'], self.path['project_name'], + self.path['experiment_name']) + + @property + def _current_run_name(self): + """ + Return the current run name. + """ + if self._current_run_name_suffix is not None: + return self.path['run_name'] + self._current_run_name_suffix + else: + return self.path['run_name'] + + @property + def _current_run_id(self): + """ + Return the current run id. + """ + if self._current_run_name_suffix is not None: + return self.base_run_id + self._current_run_name_suffix + else: + return self.base_run_id + + def get_run_status(self) -> str: + if not self.disabled: + return self._client.get_latest_dbv2_status(self._current_run_id) + + def _add_run_status(self, status): + """ + Add run status with underlying client. + + Args: + status (StatusUpdate): + The status update to add. + """ + if not self.disabled and self._env_eligible_for_recording_experiment: + self._client.add_run_status(status) + + def _record_run(self): + """ + Record the run in ML Metastore. + """ + if self.disabled or not self._env_eligible_for_recording_experiment: + return None + + if not self._client.project_exists(self.project_id): + self._client.add_project(Project(self.path['project_name'], self.path['owner'])) + time.sleep(1) + + if not self._client.experiment_exists(self.experiment_id): + self._client.add_experiment(Experiment( + self.path['experiment_name'], self.path['owner'], self.project_id, '')) + time.sleep(1) + + run = DeepbirdRun(self.experiment_id, self._current_run_name, '', + {'raw_command': ' '.join(sys.argv)}, self._params) + self._client.add_deepbird_run(run, force=True) + time.sleep(1) + + def _record_update(self, metrics): + """ + Record metrics update in ML Metastore. + + Args: + metrics (dict): + The dict of the metrics and their values. + """ + + if self.disabled or not self._env_eligible_for_recording_experiment: + return None + + reported_metrics = {} + for k, v in metrics.items(): + + if hasattr(v, 'item'): + reported_metrics[k] = v.item() if v.size == 1 else str(v.tolist()) + else: + logging.warning("Ignoring %s because the value (%s) is not valid" % (k, str(v))) + + report = ProgressReport(self._current_run_id, reported_metrics) + + try: + self._client.add_progress_report(report) + except Exception as err: + logging.error('Failed to record metrics in ML Metastore. Error: {}'.format(err)) + logging.error('Run ID: {}'.format(self._current_run_id)) + logging.error('Progress Report: {}'.format(report.to_json_string())) + + def _register_for_graceful_shutdown(self): + """ + Register the tracker with the health server, enabling graceful shutdown. + + Returns: + (Response) health server response + """ + if self._graceful_shutdown_port and not self.disabled and self._env_eligible_for_recording_experiment: + return requests.post('http://localhost:{}/register_id/{}'.format( + self._graceful_shutdown_port, + self._current_run_id + )) + + def _deregister_for_graceful_shutdown(self): + """ + Deregister the tracker with the health server, disabling graceful shutdown. 
+
+    Returns:
+      (Response) health server response
+    """
+    if self._graceful_shutdown_port and not self.disabled and self._env_eligible_for_recording_experiment:
+      return requests.post('http://localhost:{}/deregister_id/{}'.format(
+        self._graceful_shutdown_port,
+        self._current_run_id
+      ))
+
+  def _is_env_eligible_for_tracking(self):
+    """
+    Determine if experiment tracking should run in the env.
+    """
+    is_unit_test = (
+      os.environ.get('PYTEST_CURRENT_TEST') is not None and
+      os.environ.get('TEST_EXP_TRACKER') is None
+    )
+
+    is_running_on_ci = (
+      getpass.getuser() == 'scoot-service' and
+      os.environ.get('TEST_EXP_TRACKER') is None
+    )
+
+    return (
+      not is_unit_test and
+      not is_running_on_ci
+    )
+
+  @classmethod
+  def run_name_from_environ(cls):
+    """
+    Create run id from environment if possible.
+    """
+    job_name = os.environ.get("TWML_JOB_NAME")
+    job_launch_time = os.environ.get("TWML_JOB_LAUNCH_TIME")
+
+    if not job_name or not job_launch_time:
+      return None
+
+    try:
+      # job_launch_time should be in isoformat.
+      # python2 doesn't support datetime.fromisoformat, so use a hardcoded format string.
+      job_launch_time_formatted = datetime.strptime(
+        job_launch_time, "%Y-%m-%dT%H:%M:%S.%f").strftime('%m_%d_%Y_%I_%M_%p')
+    except ValueError:
+      # Fallback in case aurora config is generating datetime in a different format.
+      # This branch produces a plain string, so no further strftime is applied.
+      job_launch_time_formatted = (job_launch_time
+                                   .replace("-", "_").replace("T", "_")
+                                   .replace(":", "_").replace(".", "_"))
+
+    return '{}_{}'.format(job_name, job_launch_time_formatted)
+
+  @classmethod
+  def guess_path(cls, save_dir, run_name=None):
+    """
+    Guess an experiment tracking path based on save_dir.
+
+    Returns:
+      (str) guessed path
+    """
+    if not run_name:
+      run_name = 'Unnamed_{}'.format(datetime.now().strftime('%m_%d_%Y_%I_%M_%p'))
+
+    if save_dir.startswith('hdfs://'):
+      path_match = re.search(r'/user/([a-z0-9\-_]+)/([a-z0-9\-_]+)', save_dir)
+
+      if path_match:
+        groups = path_match.groups()
+        user = groups[0]
+        project_name = groups[1]
+
+        return generate_id(user, 'default', project_name, run_name)
+
+    user = getpass.getuser()
+    # keep only the characters that are allowed in a project name
+    project_name = re.sub(r'[^a-z0-9\-_]', '', os.path.basename(save_dir))
+    if not project_name:
+      project_name = 'unnamed'
+
+    return generate_id(user, 'default', project_name, run_name)
+
+  @classmethod
+  def compute_model_hash(cls, export_path):
+    """
+    Computes the hash of an exported model. This is a gfile version of
+    twitter.mlmetastore.common.versioning.compute_hash. The two functions should generate
+    the same hash when given the same model.
+
+    Args:
+      export_path (str):
+        The path to the exported model.
+
+    Returns:
+      (str) hash of the exported model
+    """
+    paths = []
+    for path, subdirs, files in tf.io.gfile.walk(export_path):
+      for name in sorted(files):
+        paths.append(os.path.join(path, name))
+
+    paths.sort()
+    hash_object = hashlib.new('sha1')
+
+    for path in paths:
+      with tf.io.gfile.GFile(path, "rb") as file:
+        hash_object.update(file.read())
+
+    return hash_object.hexdigest()
diff --git a/twml/twml/trainers/__init__.py b/twml/twml/trainers/__init__.py
new file mode 100644
index 000000000..e6664d9a6
--- /dev/null
+++ b/twml/twml/trainers/__init__.py
@@ -0,0 +1,10 @@
+# pylint: disable=wildcard-import
+"""
+This module contains the Trainer and DataRecordTrainer classes.
+Trainers wrap a
+`tf.estimator.Estimator
+`_.
+""" + +from .trainer import Trainer # noqa: F401 +from .data_record_trainer import DataRecordTrainer # noqa: F401 diff --git a/twml/twml/trainers/data_record_trainer.py b/twml/twml/trainers/data_record_trainer.py new file mode 100644 index 000000000..76dd16f80 --- /dev/null +++ b/twml/twml/trainers/data_record_trainer.py @@ -0,0 +1,821 @@ +# pylint: disable=arguments-differ, invalid-name +""" +This module contains the ``DataRecordTrainer``. +Unlike the parent ``Trainer`` class, the ``DataRecordTrainer`` +is used specifically for processing data records. +It abstracts away a lot of the intricacies of working with DataRecords. +`DataRecord `_ is the main piping format for data samples. +The `DataRecordTrainer` assumes training data and production responses and requests +to be organized as the `Thrift prediction service API + +A ``DataRecord`` is a Thrift struct that defines how to encode the data: + +:: + + struct DataRecord { + 1: optional set binaryFeatures; // stores BINARY features + 2: optional map continuousFeatures; // stores CONTINUOUS features + 3: optional map discreteFeatures; // stores DISCRETE features + 4: optional map stringFeatures; // stores STRING features + 5: optional map> sparseBinaryFeatures; // stores sparse BINARY features + 6: optional map> sparseContinuousFeatures; // sparse CONTINUOUS feature + 7: optional map blobFeatures; // stores features as BLOBs (binary large objects) + 8: optional map tensors; // stores TENSOR features + 9: optional map sparseTensors; // stores SPARSE_TENSOR features + } + + +A significant portion of Twitter data is hydrated +and then temporarily stored on HDFS as DataRecords. +The files are compressed (.gz or .lzo) partitions of data records. +These form supervised datasets. Each sample captures the relationship +between input and output (cause and effect). +To create your own dataset, please see https://github.com/twitter/elephant-bird. + +The default ``DataRecordTrainer.[train,evaluate,learn]()`` reads these datarecords. +The data is a read from multiple ``part-*.[compression]`` files. +The default behavior of ``DataRecordTrainer`` is to read sparse features from ``DataRecords``. +This is a legacy default piping format at Twitter. +The ``DataRecordTrainer`` is flexible enough for research and yet simple enough +for a new beginner ML practioner. + +By means of the feature string to key hashing function, +the ``[train,eval]_feature_config`` constructor arguments +control which features can be used as sample labels, sample weights, +or sample features. +Samples ids, and feature keys, feature values and feature weights +can be skipped, included, excluded or used as labels, weights, or features. +This allows you to easily define and control sparse distributions of +named features. + +Yet sparse data is difficult to work with. We are currently working to +optimize the sparse operations due to inefficiencies in the gradient descent +and parameter update processes. There are efforts underway +to minimize the footprint of sparse data as it is inefficient to process. +CPUs and GPUs much prefer dense tensor data. 
+""" + +import datetime + +import tensorflow.compat.v1 as tf +from twitter.deepbird.io.dal import dal_to_hdfs_path, is_dal_path +import twml +from twml.trainers import Trainer +from twml.contrib.feature_importances.feature_importances import ( + compute_feature_importances, + TREE, + write_feature_importances_to_hdfs, + write_feature_importances_to_ml_dash) +from absl import logging + + +class DataRecordTrainer(Trainer): # pylint: disable=abstract-method + """ + The ``DataRecordTrainer`` implementation is intended to satisfy the most common use cases + at Twitter where only the build_graph methods needs to be overridden. + For this reason, ``Trainer.[train,eval]_input_fn`` methods + assume a DataRecord dataset partitioned into part files stored in compressed (e.g. gzip) format. + + For use-cases that differ from this common Twitter use-case, + further Trainer methods can be overridden. + If that still doesn't provide enough flexibility, the user can always + use the tf.estimator.Esimator or tf.session.run directly. + """ + + def __init__( + self, name, params, + build_graph_fn, + feature_config=None, + **kwargs): + """ + The DataRecordTrainer constructor builds a + ``tf.estimator.Estimator`` and stores it in self.estimator. + For this reason, DataRecordTrainer accepts the same Estimator constructor arguments. + It also accepts additional arguments to facilitate metric evaluation and multi-phase training + (init_from_dir, init_map). + + Args: + parent arguments: + See the `Trainer constructor <#twml.trainers.Trainer.__init__>`_ documentation + for a full list of arguments accepted by the parent class. + name, params, build_graph_fn (and other parent class args): + see documentation for twml.Trainer doc. + feature_config: + An object of type FeatureConfig describing what features to decode. + Defaults to None. But it is needed in the following cases: + - `get_train_input_fn()` / `get_eval_input_fn()` is called without a `parse_fn` + - `learn()`, `train()`, `eval()`, `calibrate()` are called without providing `*input_fn`. + + **kwargs: + further kwargs can be specified and passed to the Estimator constructor. + """ + + # NOTE: DO NOT MODIFY `params` BEFORE THIS CALL. 
+    super(DataRecordTrainer, self).__init__(
+      name=name, params=params, build_graph_fn=build_graph_fn, **kwargs)
+
+    self._feature_config = feature_config
+
+    # date range parameters common to both training and evaluation data:
+    hour_resolution = self.params.get("hour_resolution", 1)
+    data_threads = self.params.get("data_threads", 4)
+    datetime_format = self.params.get("datetime_format", "%Y/%m/%d")
+
+    # retrieve the desired training dataset files
+    self._train_files = self.build_files_list(
+      files_list_path=self.params.get("train_files_list", None),
+      data_dir=self.params.get("train_data_dir", None),
+      start_datetime=self.params.get("train_start_datetime", None),
+      end_datetime=self.params.get("train_end_datetime", None),
+      datetime_format=datetime_format, data_threads=data_threads,
+      hour_resolution=hour_resolution, maybe_save=self.is_chief(),
+      overwrite=self.params.get("train_overwrite_files_list", False),
+    )
+
+    # retrieve the desired evaluation dataset files
+    eval_name = self.params.get("eval_name", None)
+
+    if eval_name == "train":
+      self._eval_files = self._train_files
+    else:
+      self._eval_files = self.build_files_list(
+        files_list_path=self.params.get("eval_files_list", None),
+        data_dir=self.params.get("eval_data_dir", None),
+        start_datetime=self.params.get("eval_start_datetime", None),
+        end_datetime=self.params.get("eval_end_datetime", None),
+        datetime_format=datetime_format, data_threads=data_threads,
+        hour_resolution=hour_resolution, maybe_save=self.is_chief(),
+        overwrite=self.params.get("eval_overwrite_files_list", False),
+      )
+
+    if not self.params.get("allow_train_eval_overlap"):
+      # if there is overlap between train and eval, error out!
+      if self._train_files and self._eval_files:
+        overlap_files = set(self._train_files) & set(self._eval_files)
+      else:
+        overlap_files = set()
+      if overlap_files:
+        raise ValueError("There is an overlap between train and eval files:\n %s" %
+                         (overlap_files))
+
+  @staticmethod
+  def build_hdfs_files_list(
+          files_list_path, data_dir,
+          start_datetime, end_datetime, datetime_format,
+          data_threads, hour_resolution, maybe_save, overwrite):
+    if files_list_path:
+      files_list_path = twml.util.preprocess_path(files_list_path)
+
+    if isinstance(start_datetime, datetime.datetime):
+      start_datetime = start_datetime.strftime(datetime_format)
+    if isinstance(end_datetime, datetime.datetime):
+      end_datetime = end_datetime.strftime(datetime_format)
+
+    list_files_by_datetime_args = {
+      "base_path": data_dir,
+      "start_datetime": start_datetime,
+      "end_datetime": end_datetime,
+      "datetime_prefix_format": datetime_format,
+      "extension": "lzo",
+      "parallelism": data_threads,
+      "hour_resolution": hour_resolution,
+      "sort": True,
+    }
+
+    # no cache of data file paths, just get the list by scraping the directory
+    if not files_list_path or not tf.io.gfile.exists(files_list_path):
+      # twml.util.list_files_by_datetime returns None if data_dir is None.
+      # twml.util.list_files_by_datetime passes through data_dir if data_dir is a list
+      files_list = twml.util.list_files_by_datetime(**list_files_by_datetime_args)
+    else:
+      # the cached data file paths file exists.
+      files_info = twml.util.read_file(files_list_path, decode="json")
+      # use the cached list if data params match current params,
+      # or if current params are None.
+      # Not including None checks for datetime_format and hour_resolution,
+      # since those are shared between eval and training.
+      if (all(param is None for param in [data_dir, start_datetime, end_datetime]) or
+          (files_info["data_dir"] == data_dir and
+           files_info["start_datetime"] == start_datetime and
+           files_info["end_datetime"] == end_datetime and
+           files_info["datetime_format"] == datetime_format and
+           files_info["hour_resolution"] == hour_resolution)):
+        files_list = files_info["files"]
+      elif overwrite:
+        # current params are not None and don't match saved params;
+        # `overwrite` indicates we should thus update the list
+        files_list = twml.util.list_files_by_datetime(**list_files_by_datetime_args)
+      else:
+        # don't update the cached list
+        raise ValueError("Information in files_list is inconsistent with provided args.\n"
+                         "Did you intend to overwrite files_list using "
+                         "--train.overwrite_files_list or --eval.overwrite_files_list?\n"
+                         "If you instead want to use the paths in files_list, ensure that "
+                         "data_dir, start_datetime, and end_datetime are None.")
+
+    if maybe_save and files_list_path and (overwrite or not tf.io.gfile.exists(files_list_path)):
+      save_dict = {}
+      save_dict["files"] = files_list
+      save_dict["data_dir"] = data_dir
+      save_dict["start_datetime"] = start_datetime
+      save_dict["end_datetime"] = end_datetime
+      save_dict["datetime_format"] = datetime_format
+      save_dict["hour_resolution"] = hour_resolution
+      twml.util.write_file(files_list_path, save_dict, encode="json")
+
+    return files_list
+
+  @staticmethod
+  def build_files_list(files_list_path, data_dir,
+                       start_datetime, end_datetime, datetime_format,
+                       data_threads, hour_resolution, maybe_save, overwrite):
+    '''
+    When specifying DAL datasets, only data_dir, start_datetime, and end_datetime
+    should be given, with the format:
+
+    dal://{cluster}/{role}/{dataset_name}/{env}
+
+    '''
+    if not data_dir or not is_dal_path(data_dir):
+      logging.warning("Please consider specifying a dal:// dataset rather than passing a physical hdfs path.")
+      return DataRecordTrainer.build_hdfs_files_list(
+        files_list_path, data_dir,
+        start_datetime, end_datetime, datetime_format,
+        data_threads, hour_resolution, maybe_save, overwrite)
+
+    del datetime_format
+    del data_threads
+    del hour_resolution
+    del maybe_save
+    del overwrite
+
+    return dal_to_hdfs_path(
+      path=data_dir,
+      start_datetime=start_datetime,
+      end_datetime=end_datetime,
+    )
+
+  @property
+  def train_files(self):
+    return self._train_files
+
+  @property
+  def eval_files(self):
+    return self._eval_files
+
+  @staticmethod
+  def add_parser_arguments():
+    """
+    Add common commandline args to parse for the Trainer class.
+    Typically, the user calls this function and then parses cmd-line arguments
+    into an argparse.Namespace object which is then passed to the Trainer constructor
+    via the params argument.
+
+    See the `Trainer code <_modules/twml/trainers/trainer.html#Trainer.add_parser_arguments>`_
+    and `DataRecordTrainer code
+    <_modules/twml/trainers/trainer.html#DataRecordTrainer.add_parser_arguments>`_
+    for a list and description of all cmd-line arguments.
+
+    Returns:
+      argparse.ArgumentParser instance with some useful args already added.
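+
+    For example (paths and dates are illustrative), a training job might be launched as:
+
+    ::
+
+        python my_trainer.py \
+          --train.data_dir hdfs://default/user/alice/training_data \
+          --train.start_date 2019/01/01 --train.end_date 2019/01/15 \
+          --eval.data_dir hdfs://default/user/alice/eval_data \
+          --eval.start_date 2019/01/16 --eval.end_date 2019/01/17 \
+          --data_spec data_spec.json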
+ """ + parser = super(DataRecordTrainer, DataRecordTrainer).add_parser_arguments() + parser.add_argument( + "--train.files_list", "--train_files_list", type=str, default=None, + dest="train_files_list", + help="Path for a json file storing information on training data.\n" + "Specifically, the file at files_list should contain the dataset parameters " + "for constructing the list of data files, and the list of data file paths.\n" + "If the json file does not exist, other args are used to construct the " + "training files list, and that list will be saved to the indicated json file.\n" + "If the json file does exist, and current args are consistent with " + "saved args, or are all None, then the saved files list will be used.\n" + "If current args are not consistent with the saved args, then error out " + "if train_overwrite_files_list==False, else overwrite files_list with " + "a newly constructed list.") + parser.add_argument( + "--train.overwrite_files_list", "--train_overwrite_files_list", action="store_true", default=False, + dest="train_overwrite_files_list", + help="When the --train.files_list param is used, indicates whether to " + "overwrite the existing --train.files_list when there are differences " + "between the current and saved dataset args. Default (False) is to " + "error out if files_list exists and differs from current params.") + parser.add_argument( + "--train.data_dir", "--train_data_dir", type=str, default=None, + dest="train_data_dir", + help="Path to the training data directory." + "Supports local, dal://{cluster}-{region}/{role}/{dataset_name}/{environment}, " + "and HDFS (hdfs://default/ ) paths.") + parser.add_argument( + "--train.start_date", "--train_start_datetime", + type=str, default=None, + dest="train_start_datetime", + help="Starting date for training inside the train data dir." + "The start datetime is inclusive." + "e.g. 2019/01/15") + parser.add_argument( + "--train.end_date", "--train_end_datetime", type=str, default=None, + dest="train_end_datetime", + help="Ending date for training inside the train data dir." + "The end datetime is inclusive." + "e.g. 2019/01/15") + parser.add_argument( + "--eval.files_list", "--eval_files_list", type=str, default=None, + dest="eval_files_list", + help="Path for a json file storing information on evaluation data.\n" + "Specifically, the file at files_list should contain the dataset parameters " + "for constructing the list of data files, and the list of data file paths.\n" + "If the json file does not exist, other args are used to construct the " + "evaluation files list, and that list will be saved to the indicated json file.\n" + "If the json file does exist, and current args are consistent with " + "saved args, or are all None, then the saved files list will be used.\n" + "If current args are not consistent with the saved args, then error out " + "if eval_overwrite_files_list==False, else overwrite files_list with " + "a newly constructed list.") + parser.add_argument( + "--eval.overwrite_files_list", "--eval_overwrite_files_list", action="store_true", default=False, + dest="eval_overwrite_files_list", + help="When the --eval.files_list param is used, indicates whether to " + "overwrite the existing --eval.files_list when there are differences " + "between the current and saved dataset args. 
Default (False) is to " + "error out if files_list exists and differs from current params.") + parser.add_argument( + "--eval.data_dir", "--eval_data_dir", type=str, default=None, + dest="eval_data_dir", + help="Path to the cross-validation data directory." + "Supports local, dal://{cluster}-{region}/{role}/{dataset_name}/{environment}, " + "and HDFS (hdfs://default/ ) paths.") + parser.add_argument( + "--eval.start_date", "--eval_start_datetime", + type=str, default=None, + dest="eval_start_datetime", + help="Starting date for evaluating inside the eval data dir." + "The start datetime is inclusive." + "e.g. 2019/01/15") + parser.add_argument( + "--eval.end_date", "--eval_end_datetime", type=str, default=None, + dest="eval_end_datetime", + help="Ending date for evaluating inside the eval data dir." + "The end datetime is inclusive." + "e.g. 2019/01/15") + parser.add_argument( + "--datetime_format", type=str, default="%Y/%m/%d", + help="Date format for training and evaluation datasets." + "Has to be a format that is understood by python datetime." + "e.g. %%Y/%%m/%%d for 2019/01/15." + "Used only if {train/eval}.{start/end}_date are provided.") + parser.add_argument( + "--hour_resolution", type=int, default=None, + help="Specify the hourly resolution of the stored data.") + parser.add_argument( + "--data_spec", type=str, required=True, + help="Path to data specification JSON file. This file is used to decode DataRecords") + parser.add_argument( + "--train.keep_rate", "--train_keep_rate", type=float, default=None, + dest="train_keep_rate", + help="A float value in (0.0, 1.0] that indicates to drop records according to the Bernoulli \ + distribution with p = 1 - keep_rate.") + parser.add_argument( + "--eval.keep_rate", "--eval_keep_rate", type=float, default=None, + dest="eval_keep_rate", + help="A float value in (0.0, 1.0] that indicates to drop records according to the Bernoulli \ + distribution with p = 1 - keep_rate.") + parser.add_argument( + "--train.parts_downsampling_rate", "--train_parts_downsampling_rate", + dest="train_parts_downsampling_rate", + type=float, default=None, + help="A float value in (0.0, 1.0] that indicates the factor by which to downsample part \ + files. For example, a value of 0.2 means only 20 percent of part files become part of the \ + dataset.") + parser.add_argument( + "--eval.parts_downsampling_rate", "--eval_parts_downsampling_rate", + dest="eval_parts_downsampling_rate", + type=float, default=None, + help="A float value in (0.0, 1.0] that indicates the factor by which to downsample part \ + files. For example, a value of 0.2 means only 20 percent of part files become part of the \ + dataset.") + parser.add_argument( + "--allow_train_eval_overlap", + dest="allow_train_eval_overlap", + action="store_true", + help="Allow overlap between train and eval datasets." + ) + parser.add_argument( + "--eval_name", type=str, default=None, + help="String denoting what we want to name the eval. If this is `train`, then we eval on \ + the training dataset." + ) + return parser + + def contrib_run_feature_importances(self, feature_importances_parse_fn=None, write_to_hdfs=True, extra_groups=None, datarecord_filter_fn=None, datarecord_filter_run_name=None): + """Compute feature importances on a trained model (this is a contrib feature) + Args: + feature_importances_parse_fn (fn): The same parse_fn that we use for training/evaluation. 
+        Defaults to feature_config.get_parse_fn().
+      write_to_hdfs (bool): Setting this to True writes the feature importance metrics to HDFS.
+      extra_groups (dict<string, list<string>>): A dictionary mapping the name of extra feature
+        groups to the list of the names of the features in the group.
+      datarecord_filter_fn (function): a function that takes a single data sample in
+        com.twitter.ml.api.ttypes.DataRecord format and returns a boolean value, indicating
+        whether this data record should be kept in the feature importance module or not.
+    """
+    logging.info("Computing feature importance")
+    algorithm = self._params.feature_importance_algorithm
+
+    kwargs = {}
+    if algorithm == TREE:
+      kwargs["split_feature_group_on_period"] = self._params.split_feature_group_on_period
+      kwargs["stopping_metric"] = self._params.feature_importance_metric
+      kwargs["sensitivity"] = self._params.feature_importance_sensitivity
+      kwargs["dont_build_tree"] = self._params.dont_build_tree
+      kwargs["extra_groups"] = extra_groups
+      if self._params.feature_importance_is_metric_larger_the_better:
+        # The user has specified that the stopping metric is one where larger values are better (e.g. ROC_AUC)
+        kwargs["is_metric_larger_the_better"] = True
+      elif self._params.feature_importance_is_metric_smaller_the_better:
+        # The user has specified that the stopping metric is one where smaller values are better (e.g. LOSS)
+        kwargs["is_metric_larger_the_better"] = False
+      else:
+        # The user has not specified which direction is better for the stopping metric
+        kwargs["is_metric_larger_the_better"] = None
+      logging.info("Using the tree algorithm with kwargs {}".format(kwargs))
+
+    feature_importances = compute_feature_importances(
+      trainer=self,
+      data_dir=self._params.get('feature_importance_data_dir'),
+      feature_config=self._feature_config,
+      algorithm=algorithm,
+      record_count=self._params.feature_importance_example_count,
+      parse_fn=feature_importances_parse_fn,
+      datarecord_filter_fn=datarecord_filter_fn,
+      **kwargs)
+
+    if not feature_importances:
+      logging.info("Feature importances returned None")
+    else:
+      if write_to_hdfs:
+        logging.info("Writing feature importance to HDFS")
+        write_feature_importances_to_hdfs(
+          trainer=self,
+          feature_importances=feature_importances,
+          output_path=datarecord_filter_run_name,
+          metric=self._params.get('feature_importance_metric'))
+      else:
+        logging.info("Not writing feature importance to HDFS")
+
+      logging.info("Writing feature importance to ML Metastore")
+      write_feature_importances_to_ml_dash(
+        trainer=self, feature_importances=feature_importances)
+    return feature_importances
+
+  def export_model(self, serving_input_receiver_fn=None,
+                   export_output_fn=None,
+                   export_dir=None, checkpoint_path=None,
+                   feature_spec=None):
+    """
+    Export the model for prediction. Typically, the exported model
+    will later be run in production servers. This method is called
+    by the user to export the PREDICT graph to disk.
+
+    Internally, this method calls ``tf.estimator.Estimator.export_savedmodel``.
+
+    Args:
+      serving_input_receiver_fn (Function):
+        function preparing the model for inference requests.
+        If not set, defaults to the serving input receiver fn set by the FeatureConfig.
+      export_output_fn (Function):
+        Function to export the graph_output (output of build_graph) for
+        prediction. Takes a graph_output dict as sole argument and returns
+        the export_output_fns dict.
+        Defaults to ``twml.export_output_fns.batch_prediction_continuous_output_fn``.
+      export_dir:
+        directory to export a SavedModel for prediction servers.
+        Defaults to ``[save_dir]/exported_models``.
+      checkpoint_path:
+        the checkpoint path to export. If None (the default), the most recent checkpoint
+        found within the model directory ``save_dir`` is chosen.
+
+    Returns:
+      The export directory where the PREDICT graph is saved.
+    """
+    if serving_input_receiver_fn is None:
+      if self._feature_config is None:
+        raise ValueError("`feature_config` was not passed to `DataRecordTrainer`")
+      serving_input_receiver_fn = self._feature_config.get_serving_input_receiver_fn()
+
+    if feature_spec is None:
+      if self._feature_config is None:
+        raise ValueError("feature_spec can not be inferred. "
+                         "Please pass feature_spec=feature_config.get_feature_spec() "
+                         "to the trainer.export_model method")
+      else:
+        feature_spec = self._feature_config.get_feature_spec()
+
+    if isinstance(serving_input_receiver_fn, twml.feature_config.FeatureConfig):
+      raise ValueError("Cannot pass FeatureConfig as a parameter to serving_input_receiver_fn")
+    elif not callable(serving_input_receiver_fn):
+      raise ValueError("Expecting Function for serving_input_receiver_fn")
+
+    if export_output_fn is None:
+      export_output_fn = twml.export_output_fns.batch_prediction_continuous_output_fn
+
+    return super(DataRecordTrainer, self).export_model(
+      export_dir=export_dir,
+      serving_input_receiver_fn=serving_input_receiver_fn,
+      checkpoint_path=checkpoint_path,
+      export_output_fn=export_output_fn,
+      feature_spec=feature_spec,
+    )
+
+  def get_train_input_fn(
+          self, parse_fn=None, repeat=None, shuffle=True, interleave=True, shuffle_files=None,
+          initializable=False, log_tf_data_summaries=False, **kwargs):
+    """
+    This method is used to create the input function used by estimator.train().
+
+    Args:
+      parse_fn:
+        Function to parse a data record into a set of features.
+        Defaults to the parser returned by the selected FeatureConfig.
+      repeat (optional):
+        Specifies if the dataset is to be repeated. Defaults to `params.train_steps > 0`.
+        This ensures the training is run for at least `params.train_steps`.
+        Toggling this to `False` results in training finishing when one of the following happens:
+          - The entire dataset has been trained upon once.
+          - `params.train_steps` has been reached.
+      shuffle (optional):
+        Specifies if the files and records in the files need to be shuffled.
+        When `True`, files are shuffled, and records of each file are shuffled.
+        When `False`, files are read in alpha-numerical order. Also, when `False`,
+        the dataset is sharded among workers for Hogwild and distributed training
+        if no sharding configuration is provided in `params.train_dataset_shards`.
+        Defaults to `True`.
+      interleave (optional):
+        Specifies if records from multiple files need to be interleaved in parallel.
+        Defaults to `True`.
+      shuffle_files (optional):
+        Shuffle the list of files. Defaults to the value of ``shuffle`` if not provided.
+      initializable (optional):
+        A boolean indicator. When the parsing function depends on some resource, e.g. a HashTable
+        or a Tensor, i.e. it's an initializable iterator, set it to True. Otherwise, the default
+        value (False) is used for most plain iterators.
+      log_tf_data_summaries (optional):
+        A boolean indicator denoting whether to add a `tf.data.experimental.StatsAggregator` to
+        the tf.data pipeline. This adds summaries of pipeline utilization and buffer sizes to the
+        output events files. This requires that `initializable` is `True` above.
+
+    Returns:
+      An input_fn that can be consumed by `estimator.train()`.
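+
+    For example (illustrative), the returned input_fn can drive a custom training call:
+
+    ::
+
+        # shuffle=False gives a deterministic read order; the steps value is arbitrary.
+        input_fn = trainer.get_train_input_fn(shuffle=False)
+        trainer.train(input_fn=input_fn, steps=1000)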
+ """ + if parse_fn is None: + if self._feature_config is None: + raise ValueError("`feature_config` was not passed to `DataRecordTrainer`") + parse_fn = self._feature_config.get_parse_fn() + + if not callable(parse_fn): + raise ValueError("Expecting parse_fn to be a function.") + + if log_tf_data_summaries and not initializable: + raise ValueError("Require `initializable` if `log_tf_data_summaries`.") + + if repeat is None: + repeat = self.params.train_steps > 0 or self.params.get('distributed', False) + + if not shuffle and self.num_workers > 1 and self.params.train_dataset_shards is None: + num_shards = self.num_workers + shard_index = self.worker_index + else: + num_shards = self.params.train_dataset_shards + shard_index = self.params.train_dataset_shard_index + + return lambda: twml.input_fns.default_input_fn( + files=self._train_files, + batch_size=self.params.train_batch_size, + parse_fn=parse_fn, + num_threads=self.params.num_threads, + repeat=repeat, + keep_rate=self.params.train_keep_rate, + parts_downsampling_rate=self.params.train_parts_downsampling_rate, + shards=num_shards, + shard_index=shard_index, + shuffle=shuffle, + shuffle_files=(shuffle if shuffle_files is None else shuffle_files), + interleave=interleave, + initializable=initializable, + log_tf_data_summaries=log_tf_data_summaries, + **kwargs) + + def get_eval_input_fn( + self, parse_fn=None, repeat=None, + shuffle=True, interleave=True, + shuffle_files=None, initializable=False, log_tf_data_summaries=False, **kwargs): + """ + This method is used to create input function used by estimator.eval(). + + Args: + parse_fn: + Function to parse a data record into a set of features. + Defaults to twml.parsers.get_sparse_parse_fn(feature_config). + repeat (optional): + Specifies if the dataset is to be repeated. Defaults to `params.eval_steps > 0`. + This ensures the evaluation is run for atleast `params.eval_steps`. + Toggling this to `False` results in evaluation finishing when one of the following happens: + - The entire dataset has been evaled upon once. + - `params.eval_steps` has been reached. + shuffle (optional): + Specifies if the files and records in the files need to be shuffled. + When `False`, files are read in alpha-numerical order. + When `True`, files are shuffled, and records of each files are shuffled. + Defaults to `True`. + interleave (optional): + Specifies if records from multiple files need to be interleaved in parallel. + Defaults to `True`. + shuffle_files (optional): + Shuffles the list of files. Defaults to 'Shuffle' if not provided. + initializable (optional): + A boolean indicator. When the parsing function depends on some resource, e.g. a HashTable or + a Tensor, i.e. it's an initializable iterator, set it to True. Otherwise, default value + (false) is used for most plain iterators. + log_tf_data_summaries (optional): + A boolean indicator denoting whether to add a `tf.data.experimental.StatsAggregator` to the + tf.data pipeline. This adds summaries of pipeline utilization and buffer sizes to the output + events files. This requires that `initializable` is `True` above. + + Returns: + An input_fn that can be consumed by `estimator.eval()`. 
+ """ + if parse_fn is None: + if self._feature_config is None: + raise ValueError("`feature_config` was not passed to `DataRecordTrainer`") + parse_fn = self._feature_config.get_parse_fn() + + if not self._eval_files: + raise ValueError("`eval_files` was not present in `params` passed to `DataRecordTrainer`") + + if not callable(parse_fn): + raise ValueError("Expecting parse_fn to be a function.") + + if log_tf_data_summaries and not initializable: + raise ValueError("Require `initializable` if `log_tf_data_summaries`.") + + if repeat is None: + repeat = self.params.eval_steps > 0 + + return lambda: twml.input_fns.default_input_fn( + files=self._eval_files, + batch_size=self.params.eval_batch_size, + parse_fn=parse_fn, + num_threads=self.params.num_threads, + repeat=repeat, + keep_rate=self.params.eval_keep_rate, + parts_downsampling_rate=self.params.eval_parts_downsampling_rate, + shuffle=shuffle, + shuffle_files=(shuffle if shuffle_files is None else shuffle_files), + interleave=interleave, + initializable=initializable, + log_tf_data_summaries=log_tf_data_summaries, + **kwargs + ) + + def _assert_train_files(self): + if not self._train_files: + raise ValueError("train.data_dir was not set in params passed to DataRecordTrainer.") + + def _assert_eval_files(self): + if not self._eval_files: + raise ValueError("eval.data_dir was not set in params passed to DataRecordTrainer.") + + def train(self, input_fn=None, steps=None, hooks=None): + """ + Makes input functions optional. input_fn defaults to self.get_train_input_fn(). + See Trainer for more detailed documentation documentation. + """ + if input_fn is None: + self._assert_train_files() + input_fn = input_fn if input_fn else self.get_train_input_fn() + super(DataRecordTrainer, self).train(input_fn=input_fn, steps=steps, hooks=hooks) + + def evaluate(self, input_fn=None, steps=None, hooks=None, name=None): + """ + Makes input functions optional. input_fn defaults to self.get_eval_input_fn(). + See Trainer for more detailed documentation. + """ + if input_fn is None: + self._assert_eval_files() + input_fn = input_fn if input_fn else self.get_eval_input_fn(repeat=False) + return super(DataRecordTrainer, self).evaluate( + input_fn=input_fn, + steps=steps, + hooks=hooks, + name=name + ) + + def learn(self, train_input_fn=None, eval_input_fn=None, **kwargs): + """ + Overrides ``Trainer.learn`` to make ``input_fn`` functions optional. + Respectively, ``train_input_fn`` and ``eval_input_fn`` default to + ``self.train_input_fn`` and ``self.eval_input_fn``. + See ``Trainer.learn`` for more detailed documentation. + """ + if train_input_fn is None: + self._assert_train_files() + if eval_input_fn is None: + self._assert_eval_files() + train_input_fn = train_input_fn if train_input_fn else self.get_train_input_fn() + eval_input_fn = eval_input_fn if eval_input_fn else self.get_eval_input_fn() + + super(DataRecordTrainer, self).learn( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + **kwargs + ) + + def train_and_evaluate(self, + train_input_fn=None, eval_input_fn=None, + **kwargs): + """ + Overrides ``Trainer.train_and_evaluate`` to make ``input_fn`` functions optional. + Respectively, ``train_input_fn`` and ``eval_input_fn`` default to + ``self.train_input_fn`` and ``self.eval_input_fn``. + See ``Trainer.train_and_evaluate`` for detailed documentation. 
+ """ + if train_input_fn is None: + self._assert_train_files() + if eval_input_fn is None: + self._assert_eval_files() + train_input_fn = train_input_fn if train_input_fn else self.get_train_input_fn() + eval_input_fn = eval_input_fn if eval_input_fn else self.get_eval_input_fn() + + super(DataRecordTrainer, self).train_and_evaluate( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + **kwargs + ) + + def _model_fn(self, features, labels, mode, params, config=None): + """ + Overrides the _model_fn to correct for the features shape of the sparse features + extracted with the contrib.FeatureConfig + """ + if isinstance(self._feature_config, twml.contrib.feature_config.FeatureConfig): + # Fix the shape of the features. The features dictionary will be modified to + # contain the shape changes. + twml.util.fix_shape_sparse(features, self._feature_config) + return super(DataRecordTrainer, self)._model_fn( + features=features, + labels=labels, + mode=mode, + params=params, + config=config + ) + + def calibrate(self, + calibrator, + input_fn=None, + steps=None, + save_calibrator=True, + hooks=None): + """ + Makes input functions optional. input_fn defaults to self.train_input_fn. + See Trainer for more detailed documentation. + """ + if input_fn is None: + self._assert_train_files() + input_fn = input_fn if input_fn else self.get_train_input_fn() + super(DataRecordTrainer, self).calibrate(calibrator=calibrator, + input_fn=input_fn, + steps=steps, + save_calibrator=save_calibrator, + hooks=hooks) + + def save_checkpoints_and_export_model(self, + serving_input_receiver_fn, + export_output_fn=None, + export_dir=None, + checkpoint_path=None, + input_fn=None): + """ + Exports saved module after saving checkpoint to save_dir. + Please note that to use this method, you need to assign a loss to the output + of the build_graph (for the train mode). + See export_model for more detailed information. + """ + self.train(input_fn=input_fn, steps=1) + self.export_model(serving_input_receiver_fn, export_output_fn, export_dir, checkpoint_path) + + def save_checkpoints_and_evaluate(self, + input_fn=None, + steps=None, + hooks=None, + name=None): + """ + Evaluates model after saving checkpoint to save_dir. + Please note that to use this method, you need to assign a loss to the output + of the build_graph (for the train mode). + See evaluate for more detailed information. + """ + self.train(input_fn=input_fn, steps=1) + self.evaluate(input_fn, steps, hooks, name) diff --git a/twml/twml/trainers/trainer.py b/twml/twml/trainers/trainer.py new file mode 100644 index 000000000..e51b4e0fd --- /dev/null +++ b/twml/twml/trainers/trainer.py @@ -0,0 +1,1777 @@ +# pylint: disable=too-many-lines +""" +``twml.trainers.Trainer`` is a wrapper around `tf.estimator.Estimator +`_ +to expose an easier to use API by +hiding rarely used config knobs and supplying default values. + +The `Trainer` facilitates multi-phase training commonly used at Twitter: e.g. +MDL calibration -> MLP training -> Isotonic calibration. +The `Trainer` also facilitates hyperparameters tuning, +with its simple `add_parser_arguments()` method. + +Learning rate decay functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Please note that we have four learning rate decay functions to choose from. +Additionally, each trainer can only take one learning rate decay function and its parameters. +If that is not the case, it will throw an error. 
+Also, please note that the learning rate decay arguments are expected to be placed
+last on the trainer command line.
+The four learning rate decay options are:
+
+1. inverse_learning_rate_decay:
+
+   The function returns the decayed learning rate. It is computed as:
+
+   ::
+
+     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
+     final_decayed_learning_rate = max(decayed_learning_rate, min_learning_rate)
+
+2. polynomial_learning_rate_decay:
+
+   The function returns the decayed learning rate. It is computed as:
+
+   ::
+
+     global_step = min(global_step, decay_steps)
+     decayed_learning_rate = (learning_rate - end_learning_rate) *
+                             (1 - global_step / decay_steps) ^ (power) +
+                             end_learning_rate
+
+3. piecewise_constant_learning_rate_decay:
+
+   Piecewise constant from boundaries and interval values.
+
+   Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 for
+   the next 10000 steps, and 0.1 for any additional steps.
+
+   ::
+
+     global_step = tf.Variable(0, trainable=False)
+     boundaries = [100000, 110000]
+     values = [1.0, 0.5, 0.1]
+     learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
+
+4. exponential_learning_rate_decay:
+
+   The function returns the decayed learning rate. It is computed as:
+
+   ::
+
+     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
+
+"""
+
+import datetime
+import functools
+import math
+from operator import itemgetter
+import os
+import pprint as pp
+import random
+from string import Template
+import subprocess
+import sys
+import time
+from threading import Thread
+
+from twitter.common.metrics import AtomicGauge
+from twitter.deepbird.stats_server import utils as stats_server_utils
+from twitter.deepbird.stats_server.stats_exporter import StatsExporter
+from twitter.ml.common import metrics
+from twitter.ml.common.kubernetes import kubectl_delete_by_name, Resource
+from twitter.ml.twml.status import get_distributed_training_job_status, TrainingJobStatus
+
+from absl import logging
+from twml.optimizers import LazyAdamOptimizer, optimize_loss, OPTIMIZER_SUMMARIES
+from twml.contrib.optimizers import DeepGradientCompressionOptimizer
+from twml.tracking import ExperimentTracker
+from twml.util import (delete_file_or_dir,
+                       get_distributed_training_job_path,
+                       sanitize_hdfs_path)
+try:
+  from urllib import quote as encode_url
+except ImportError:
+  from urllib.parse import quote as encode_url
+import tensorflow.compat.v1 as tf
+import tensorflow
+import tensorflow_hub as hub
+
+import twitter.ml.twml.kubernetes.status as k8s_status
+import twml
+import twml.export_output_fns
+import twml.learning_rate_decay
+import twml.metrics
+
+
+_CLUSTER_TEMPLATE = Template('''{
+  "cluster": {
+    "ps": [$PS],
+    "chief": [$CHIEF],
+    "worker": [$WORKER]
+  },
+  "task": {"type": "$TYPE", "index": $INDEX}
+}
+''')
+
+
+def init_from_checkpoint(init_dir, init_map):
+  """
+  Wrapper around tf.train.init_from_checkpoint
+  """
+  if init_dir:
+    init_dir = sanitize_hdfs_path(init_dir)
+    tf.train.init_from_checkpoint(init_dir, init_map)
+
+
+class Trainer(object):
+  """
+  This class wraps ``tf.estimator.Estimator`` to make construction, saving, and loading easier.
+  Supports multi-phase training (for example, use a Trainer for MDL calibration, then
+  another for training the rest of the model, then another for isotonic calibration).
+  The Trainer also implements a training and evaluation loop via the ``learn()`` method.
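+
+  For example (illustrative only; ``params`` and the input functions come from user code):
+
+  ::
+
+      trainer = Trainer(name="my_model", params=params, build_graph_fn=build_graph)
+      trainer.learn(train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
+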
+  Each Trainer is associated to a fixed set of hyper parameters (params), and a single model
+  specified by ``build_graph``. Given these constraints, a single Trainer can be called
+  multiple times for training and evaluation over multiple epochs.
+
+  However, if you intend to try different sets of hyper-parameters, we recommend you instantiate
+  a different Trainer for each such experiment. That way, each experiment can be tracked
+  in a different ``save_dir``. Indeed, after calling ``learn``, a Trainer's save_dir will contain
+  checkpoints of the model (its graph and variables), the history of metrics (for example,
+  evaluation accuracy at each epoch), and other stored observations, such as the average time
+  per step. The latter metrics can be viewed by pointing
+  TensorBoard to the save_dir and accessing TensorBoard via your browser.
+  """
+
+  def __init__(self, name, params, build_graph_fn,
+               metric_fn=None,
+               optimize_loss_fn=None,
+               run_config=None,
+               save_dir=None,
+               init_from_dir=None,
+               init_map=None,
+               warm_start_from=None,
+               profiler_steps=None,
+               **kwargs):
+    """
+
+    Args:
+      name (String):
+        string name of this estimator; used as scope names for variables and tensors.
+      params (HParams, Namespace, or Dict):
+        hyper-parameters to be passed to Estimator constructor.
+        Must include params.train_batch_size and params.eval_batch_size.
+        Note that params is passed to twml.util.convert_to_hparams() to produce an HParams.
+      build_graph_fn:
+        A function for building tensorflow graphs.
+        This matches TensorFlow Estimator's model_fn signature.
+        For example,
+
+        .. code-block:: python
+
+          def build_graph(features, label, mode, params, config=None):
+            # Implements a simple binary logistic regression model
+            sparse_tf = twml.util.convert_to_sparse(features, params.input_size_bits)
+
+            logits = twml.layers.full_sparse(sparse_tf, 1 << params.input_size_bits, 1)
+
+            if mode == 'infer':
+              loss = None
+            else:
+              loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits)
+              loss = twml.util.weighted_average(loss, features['weights'])
+
+            output = tf.nn.sigmoid(logits)
+
+            return {'output': output, 'loss': loss}
+
+        Args:
+          features (dict of Tensor keyed by a string name):
+            input tensors.
+          mode (tf.estimator.ModeKeys / String):
+            one of 'train', 'eval', 'infer'.
+          label (Tensor):
+            if in ``mode == 'train'``, contains the corresponding labels for the input.
+          params (HParams):
+            hyper parameters that control how to build a graph.
+          config:
+            the RunConfig object passed to the Estimator constructor.
+
+        This function is expected to return a dictionary containing the following keys:
+
+        * 'output': a node representing model output; required.
+        * 'loss': a loss node used for optimization; required for training and
+          evaluation.
+        * 'train_op': (optional) an operation that minimizes the loss (as output by
+          `tf.train.Optimizer.minimize`). If train_op is specified, train_op is used
+          for optimization as opposed to loss. Loss is always logged to tensorboard.
+
+        Notes:
+
+        * any tf.summary written inside ``build_graph_fn`` is logged to tensorboard
+          during training.
+        * the ``build_graph_fn`` is called once or twice per epoch (once per training,
+          once per evaluation). All data loading (and preprocessing) logic not required
+          for serving should be in the ``input_fn`` passed to ``learn``, ``train``,
+          ``evaluate``, etc.
+
+      optimize_loss_fn:
+        Defaults to Trainer.get_train_op. A function that takes params and loss as arguments
+        and returns a training op. The training op is used to update parameters (that is, to learn).
+      metric_fn:
+        A function that returns the eval_metric_ops dict given graph_output, labels and weights.
+        Defaults to None.
+        Use ``twml.metrics.get_binary_class_metric_fn()`` to return a ``metric_fn``
+        which implements many binary classification metrics.
+      run_config (RunConfig):
+        optional configuration to be passed to Estimator constructor. Defaults to None.
+      save_dir (String):
+        optional directory where to save model checkpoints,
+        tensorboard event files and trained parameters.
+        Overwrites and defaults to run_config.model_dir.
+      init_from_dir (String):
+        optional directory to load weights from.
+        If set to None (the default), do not init from any directory.
+      init_map (map from String to String):
+        Must be specified if init_from_dir is specified.
+        Defines which scopes and variables to load.
+        Keys are the variables and scopes to load from the directory.
+        Values are the destinations (in the current graph) to load into.
+        See tf.init_from_checkpoint for more information.
+        Note that the trainer prepends a name_scope of the form `name`/model/ to the name_scope
+        of any variable defined inside `build_graph_fn`, and this should be taken into account
+        when defining the values.
+      warm_start_from:
+        Optional string filepath to a checkpoint to warm-start from,
+        or a tf.estimator.WarmStartSettings object to fully configure warm-starting.
+        If the string filepath is provided instead of a WarmStartSettings,
+        then all variables are warm-started, and it is assumed that
+        vocabularies and Tensor names are unchanged.
+      profiler_steps (Integer):
+        Defaults to None. If set, defines the number of steps in the
+        ``tf.train.ProfilerHook``.
+        Captures CPU/GPU profiling information every ``profiler_steps`` steps or seconds.
+        When executing the ``learn``, ``train`` or ``predict`` methods
+        with ``profiler_steps`` set to a number,
+        a ``timeline_X.json`` file is created in the save_dir. This file contains profiling data
+        stored in Chrome trace format. To view the stored data, use the Chrome browser to follow
+        these steps:
+
+        1) Go to the page chrome://tracing.
+        2) In the upper left corner, you will find the Load button.
+        3) Press it and load our JSON file, which can be found in the ``save_dir``.
+
+        *Warning*: This can create many of these JSON files, which can be a potential problem;
+        e.g. on HDFS there is normally a quota for file count, so use with caution.
+
+        Note: this argument is ignored when a non-None ``hooks`` argument is passed to
+        the ``train``, ``learn``, or ``predict`` methods. The hook can be added manually by
+        passing ``trainer.train(..., hooks=myhooks.extend(trainer.get_train_hooks()))``,
+        for example.
+    """
+
+    if tensorflow.__version__ >= "2.0":
+      # NOTE: the original line constructed this exception without raising it;
+      # raising is clearly the intent here.
+      raise RuntimeError("Trainer not yet supported for Tensorflow >= 2.0")
+
+    self._name = name
+    self._build_graph_fn = build_graph_fn
+    self._metric_fn = metric_fn
+    self._tensorboard_handle = None
+    self._current_estimator_spec = None  # holds the current estimator spec
+    self._profiler_steps = profiler_steps
+    self._export_output_fn = None
+    self._is_early_stopping = False
+
+    # NOTE: Sanitize all HDFS paths first.
+    save_dir = sanitize_hdfs_path(save_dir)
+    init_from_dir = sanitize_hdfs_path(init_from_dir)
+
+    # warm_start_from can be of type tf.estimator.WarmStartSettings.
+    if isinstance(warm_start_from, str):
+      warm_start_from = sanitize_hdfs_path(warm_start_from)
+
+    # convert to twitter.deepbird.hparam.hparam.HParams object
+    params = twml.util.convert_to_hparams(params)
+
+    # keep a copy of the params because calling self._estimator.params creates a deepcopy
+    self._params = params
+    self.check_params()
+
+    self._using_hogwild = True if os.environ.get('TWML_HOGWILD_PORTS') else False
+    # configure Hogwild (needs to be called before RunConfig is created)
+    self._hogwild_setup()
+
+    if not run_config:
+      session_config = tf.ConfigProto()
+      # By default each process tries to allocate (almost) all of the memory.
+      # This option ensures the gpu memory grows dynamically instead.
+      session_config.gpu_options.allow_growth = True  # pylint: disable=no-member
+
+      if 'TWML_NUM_CPUS' in os.environ:
+        num_available_cpus = int(os.environ.get("TWML_MESOS_CPU", "8"))
+        if params.num_mkl_threads > 1:
+          os.environ["OMP_NUM_THREADS"] = str(params.num_mkl_threads)
+          os.environ["MKL_NUM_THREADS"] = str(params.num_mkl_threads)
+          session_config.inter_op_parallelism_threads = num_available_cpus // params.num_mkl_threads
+          session_config.intra_op_parallelism_threads = params.num_mkl_threads
+
+      run_config = tf.estimator.RunConfig(
+        session_config=session_config,
+        keep_checkpoint_max=self._params.get('keep_checkpoint_max', 20),
+        log_step_count_steps=10000,
+        save_checkpoints_secs=self._params.get('save_checkpoints_secs', 600),
+        tf_random_seed=self._tf_random_seed())
+    elif not isinstance(run_config, tf.estimator.RunConfig):
+      raise ValueError("Expecting run_config argument of type None or tf.estimator.RunConfig. "
+                       "Got %s instead." % type(run_config).__name__)
+    elif os.environ.get('TWML_HOGWILD_PORTS'):
+      raise ValueError("Custom RunConfig not supported with Hogwild")
+
+    if run_config.model_dir is None and save_dir is None:
+      raise ValueError(
+        "Expecting either save_dir or run_config.model_dir to be specified. Got None for each.")
+    elif run_config.model_dir is None:
+      run_config = run_config.replace(model_dir=save_dir)
+    elif save_dir is None:
+      save_dir = run_config.model_dir
+
+    self._save_dir = save_dir
+    self.experiment_tracker = ExperimentTracker(self._params, run_config, self._save_dir)
+
+    # Check if we should delete the tsd running this training job. In certain use cases, when
+    # there are other tf operations following trainer.train_and_evaluate (or trainer.learn),
+    # additional state files need to be specified to ensure those steps are executed after job restart.
+    kwargs['gke_state_files'] = kwargs.get('gke_state_files', ['_SUCCESS'])
+    self._maybe_del_tsd_exit(kwargs['gke_state_files'])
+    logging.info("Checkpoint and event files will be saved at save_dir=%s", save_dir)
+    self._optimize_loss_fn = self.get_train_op if optimize_loss_fn is None else optimize_loss_fn
+
+    # overwrite the current save_dir
+    if self._params.get('overwrite_save_dir') and tf.io.gfile.exists(self._save_dir):
+      logging.info("Trainer overwriting existing save directory: %s (params.overwrite_save_dir)"
+                   % self._save_dir)
+      # if distributed or hogwild:
+      if self._params.get('distributed', False):
+        # sleep for 30 seconds to allow each worker to get to this point.
+        time.sleep(30)
+        if run_config.is_chief:
+          logging.info("Chief deleting the save_dir now")
+          delete_file_or_dir(self._save_dir)
+        # sleep for 30 seconds to allow each worker to get to this point.
+        time.sleep(30)
+      else:
+        delete_file_or_dir(self._save_dir)
+
+    # Exposing stats to a /vars.json endpoint that will be collected
+    # by the absorber
+    if self._params.get('stats_port'):
+      try:
+        stats_server_utils.start_stats_server(self._params.get('stats_port'), self._save_dir)
+      except Exception as err:
+        logging.error('Failed to start the stats server. Error: %s', str(err))
+
+    checkpoint = os.path.join(self._save_dir, 'checkpoint')
+    if tf.io.gfile.exists(checkpoint):
+      logging.info("The provided save_dir directory %s already exists."
+                   " Training will be resumed."
+                   % checkpoint)
+
+    self._maybe_restore_checkpoint = lambda: init_from_checkpoint(init_from_dir, init_map)
+
+    if init_from_dir is not None and init_map is None:
+      raise ValueError("Need to provide init_map when init_from_dir is provided.")
+
+    if not tf.io.gfile.exists(self._save_dir):
+      # so tensorboard can point to a directory that exists
+      tf.io.gfile.mkdir(self._save_dir)
+
+    self._estimator = tf.estimator.Estimator(
+      model_fn=self._model_fn,
+      params=self._params,  # HParams
+      config=run_config,  # RunConfig
+      warm_start_from=warm_start_from,
+      model_dir=self._save_dir,  # By this point it is the same as run_config.model_dir
+    )
+
+    # Log parameters that are used to construct trainer. This allows people to see default values.
+    logging.info("Trainer constructed using the following parameters: ")
+    pp_params = pp.pformat(self._params.values())
+    logging.info(pp_params)
+
+    # Start TensorBoard
+    if self._params.get('disable_tensorboard', False):
+      logging.info("Skipping launching TensorBoard [--disable_tensorboard is set]")
+    elif "tensorboard_port" in self._params.values() and self._params.tensorboard_port is not None:
+      self.start_tensorboard(self._params.tensorboard_port)
+
+    # Export gauge that will track whether a model was exported
+    self.stats_exporter = StatsExporter("twml.trainer")
+    self.export_gauge = AtomicGauge('export_model')
+    self.stats_exporter.register_metrics(self.export_gauge)
+
+  def _hogwild_setup(self):
+    """
+    Setup the parameters required for Hogwild.
+ """ + self._num_workers = self._params.get('num_workers') or 1 + logging.info("NUM_WORKERS: %d", self._num_workers) + if self._num_workers <= 1: + self._ports = None + return + + # a hogwild job is considered distributed + if 'distributed' in self._params: + self._params.set_hparam('distributed', True) + else: + self._params.add_hparam('distributed', True) + + ports = os.environ.get('TWML_HOGWILD_PORTS') + if ports: + self._ports = [int(port) for port in ports.strip().split(",")] + if (self._num_workers + 1!= len(self._ports)): + raise ValueError("Number of (workers + PS) and ports need to match") + else: + if self._num_workers > 1: + raise ValueError("TWML_HOGWILD_PORTS needs to be set to use hogwild training") + + # Split the number of data threads across multiple workers + num_threads = self._params.get('num_threads') + num_threads_per_worker = int(math.ceil(float(num_threads) / self._num_workers)) + self._params.set_hparam('num_threads', num_threads_per_worker) + + hogwild_task_type = os.environ.get('TWML_HOGWILD_TASK_TYPE') + hogwild_task_id = int(os.environ.get('TWML_HOGWILD_TASK_ID')) + os.environ['TF_CONFIG'] = self._get_cluster_config(hogwild_task_type, hogwild_task_id) + + def _tf_random_seed(self): + """ Returns user set seed and deal with Hogwild multiple seeds """ + tf_random_seed = self._params.get('tf_random_seed', None) + if tf_random_seed is None: + return None + elif self.using_hogwild and os.environ.get('TWML_HOGWILD_TASK_TYPE') == 'worker': + # chief (tf_random_seed), worker_0 (tf_random_seed + 1), worker_1 (tf_random_seed + 2)... + return tf_random_seed + 1 + int(os.environ.get('TWML_HOGWILD_TASK_ID')) + else: + return tf_random_seed + + def check_params(self): + """ Verify that params has the correct key,values """ + param_values = self._params.values() + + if 'train_batch_size' in param_values: + if not isinstance(self._params.train_batch_size, int): + raise ValueError("Expecting params.train_batch_size to be an integer.") + if self._params.train_batch_size <= 0: + raise ValueError("train_batch_size needs to be positive") + else: + raise ValueError("train_batch_size needs to be present in params") + + if 'eval_batch_size' in param_values: + if not isinstance(self._params.eval_batch_size, int): + raise ValueError("Expecting params.eval_batch_size to be an integer.") + if self._params.eval_batch_size <= 0: + raise ValueError("eval_batch_size needs to be positive.") + else: + self._params.add_hparam('eval_batch_size', self._params.train_batch_size) + + if (self._params.get('distributed_training_cleanup') and + not self._params.get('distributed')): + # we only need to support training discontinuation for distributed training + # bc we are still using TSDs on GKE for distributed training + raise ValueError( + "Expecting params.distributed to be set if " + "params.distributed_training_cleanup is set." 
+      )
+
+  def _get_cluster_config(self, name, index):
+    """Create a tensorflow cluster config from ports, name and index"""
+    host = '"localhost:%d"'
+    ps = host % self._ports[0]
+    chief = host % self._ports[1]
+    workers = ", ".join([host % port for port in self._ports[2:]])
+    config = _CLUSTER_TEMPLATE.substitute(
+      PS=ps,
+      CHIEF=chief,
+      WORKER=workers,
+      TYPE=name,
+      INDEX=index,
+    )
+    return config
+
+  @property
+  def current_estimator_spec(self):
+    """
+    returns the current estimator spec (warning: often reset)
+    """
+    return self._current_estimator_spec
+
+  @property
+  def estimator(self):
+    """ returns the estimator encapsulated by the Trainer """
+    return self._estimator
+
+  @property
+  def num_workers(self):
+    """ returns the number of workers """
+    return self._estimator.config.num_worker_replicas
+
+  @property
+  def worker_index(self):
+    """
+    returns the index of the worker in the cluster;
+    the chief has index 0, and non-chief workers have indices 1 through (num_workers - 1)
+    """
+    return self._estimator.config.global_id_in_cluster
+
+  @property
+  def using_hogwild(self):
+    """ returns a bool indicating whether hogwild is being used """
+    return self._using_hogwild
+
+  def set_estimator(self, estimator):
+    """ sets the estimator used internally by the Trainer """
+    if not isinstance(estimator, tf.estimator.Estimator):
+      raise ValueError("Expecting tf.estimator.Estimator")
+    self._estimator = estimator
+    self._params = self.estimator.params
+
+  @property
+  def params(self):
+    """
+    returns the hyper-parameters passed to the constructor.
+    """
+    return self._params
+
+  @staticmethod
+  def add_parser_arguments():
+    """
+    Add common commandline args to parse for the Trainer class.
+    Typically, the user calls this function and then parses cmd-line arguments
+    into an argparse.Namespace object which is then passed to the Trainer constructor
+    via the params argument.
+
+    See the `code <_modules/twml/argument_parser.html#get_trainer_parser>`_
+    for a list and description of all cmd-line arguments.
+
+    Returns:
+      argparse.ArgumentParser instance with some useful args already added.
+    """
+    return twml.argument_parser.get_trainer_parser()
+
+  @staticmethod
+  def get_train_op(params, loss):
+    """
+    Return a training Op, that is, a ``twml.optimizers.optimize_loss``
+    instance, given params and loss.
+    This method can be overwritten by passing the optimize_loss_fn to the Trainer
+    constructor.
+
+    Args:
+      params:
+        tensorflow.contrib.training.HParams instance. Recognizes the optimizer, optimizer_summaries,
+        gradient_noise_scale, clip_gradients and learning_rate_decay (including
+        other learning rate decay arguments).
+      loss:
+        scalar Op returned by the build_graph that specifies the training loss to
+        be minimized.
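+
+    As a quick illustration (parameter values are hypothetical):
+
+    ::
+
+        # assuming params carries optimizer='LazyAdam' and learning_rate=0.01
+        train_op = Trainer.get_train_op(params, loss)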
+ """ + optimizer = params.get('optimizer') + + if not optimizer: + optimizer = 'SGD' + + if optimizer == 'LazyAdam': + optimizer = LazyAdamOptimizer + + if optimizer == 'DGC': + optimizer = DeepGradientCompressionOptimizer( + learning_rate=params.learning_rate, + use_locking=False, + name="Sparse", + density=params.get('dgc_density'), + density_decay=params.get('dgc_density_decay'), + density_decay_steps=params.get('dgc_density_decay_steps'), + density_decay_rate=params.get('dgc_density_decay_rate'), + min_density=params.get('dgc_min_density'), + accumulation=params.get('dgc_accumulation') + ) + + summaries = ['loss'] + if params.get('show_optimizer_summaries'): + summaries = OPTIMIZER_SUMMARIES + + train_op = optimize_loss( + loss=loss, + global_step=tf.train.get_global_step(), + optimizer=optimizer, + learning_rate=params.learning_rate, + summaries=summaries, + colocate_gradients_with_ops=True, + gradient_noise_scale=params.get('gradient_noise_scale'), + clip_gradients=params.get('clip_gradients'), + learning_rate_decay_fn=twml.learning_rate_decay.get_learning_rate_decay_fn(params) + ) + return train_op + + def export_model_effects(self, export_path, feature_spec=None, log_features=True): + + # DO NOT CHANGE THE ORDER. + # This needs to be done before registering the model. + if feature_spec: + if log_features: + features = feature_spec['features'] + feature_names = ['.'.join(features[fid]['featureName'].split('.')[1:]) for fid in features.keys()] + features_to_log = ','.join(feature_names) + try: + model_hash = self.experiment_tracker.compute_model_hash(export_path) + metrics.log_usage('dbv2', 'export_model_effects', 'v1', custom_attrs=[model_hash, "feature config present", features_to_log]) + except: # noqa: T803 + logging.info("Failed to log Feature Config features") + + twml.contrib.export.export_fn.export_feature_spec(export_path, feature_spec) + export_start_time = time.time() + self.experiment_tracker.export_feature_spec(feature_spec) + logging.info("Exported feature spec to ML Metastore in %s seconds.", time.time() - export_start_time) + + self.experiment_tracker.register_model(str(export_path)) + self.export_gauge.increment() + + @property + def best_or_latest_checkpoint(self): + if self._is_early_stopping: + best_checkpoint_path = os.path.join(self._save_dir, "best_checkpoint") + checkpoint_path = tf.train.latest_checkpoint(best_checkpoint_path) + # Return best checkpoint if necessary + if checkpoint_path: + return checkpoint_path + else: + raise ValueError("Best checkpoint not found at %s." % best_checkpoint_path) + else: # Fallback to latest checkpoint from save directory + return self.latest_checkpoint + + @property + def latest_checkpoint(self): + return self.estimator.latest_checkpoint() + + def export_model(self, serving_input_receiver_fn, + export_output_fn=None, + export_dir=None, checkpoint_path=None, + feature_spec=None, + log_features=True): + """ + Export the model for prediction. Typically, the exported model + will later be run in production servers. This method is called + by the user to export the PREDICTgraph to disk. + + Internally, this method calls `tf.estimator.Estimator.export_savedmodel + `_. + + Note that a valid self._export_output_fn is required. + If export_ouput_fn is provided, it is used to set the self._export_output_fn. + + Args: + serving_input_receiver_fn: + function preparing the model for inference requests. + This funtion returns the ``features`` dict passed to ``build_graph``. 
+      export_dir:
+        directory to export a SavedModel for prediction servers.
+        Defaults to ``[save_dir]/exported_models``.
+      checkpoint_path:
+        the checkpoint path to export. If None (the default), the most recent checkpoint
+        found within the model directory is chosen.
+      export_output_fn:
+        Function to export the graph_output (output of build_graph) for
+        prediction. Takes a graph_output dict as sole argument and returns
+        the export_outputs dict.
+        Defaults to `twml.export_output_fns.default_output_fn`.
+
+    Returns:
+      returns a string path to the exported directory.
+    """
+    if not self.is_chief():
+      logging.info("Trainer.export_model ignored due to the process not being chief.")
+      return
+
+    # set the export output function
+    self._export_output_fn = export_output_fn or twml.export_output_fns.default_output_fn
+
+    if not callable(self._export_output_fn):
+      raise RuntimeError(
+          "Expecting export_output_fn function. Got %s."
+          % type(self._export_output_fn).__name__)
+
+    if export_dir:
+      export_dir = sanitize_hdfs_path(export_dir)
+
+    if checkpoint_path:
+      checkpoint_path = sanitize_hdfs_path(checkpoint_path)
+    else:
+      checkpoint_path = self.best_or_latest_checkpoint
+
+    # actually export the model using the Estimator API
+    export_path = self._estimator.export_savedmodel(
+        export_dir_base=export_dir or os.path.join(self._save_dir, 'exported_models'),
+        serving_input_receiver_fn=serving_input_receiver_fn,
+        checkpoint_path=checkpoint_path)
+
+    # export_path is bytes; it needs to be converted to a string for python3 to work.
+    logging.info("The exported model path is: " + str(export_path))
+
+    self.export_model_effects(export_path, feature_spec, log_features)
+
+    return export_path
+
+  def _model_fn(self, features, labels, mode, params, config=None):
+    """
+    returns a tf.estimator.EstimatorSpec that can be used with tf.estimator.Estimators.
+    You would probably never need to modify this method.
+    Instead, you should override build_graph, which this method calls.
+
+    Args:
+      features:
+        Dict of input tensors.
+      labels:
+        Tensor of target labels.
+      mode:
+        an instance of tf.estimator.ModeKeys.
+        Typically used to toggle TRAINing or EVALuation.
+      params:
+        HParams object containing hyper-parameters.
+    """
+    # pylint: disable=too-many-branches
+    if isinstance(features, dict):
+      weights = features.get('weights', None)
+    else:
+      weights = None
+
+    with tf.variable_scope(self._name + '/model'):
+      graph_output = self._build_graph_fn(features, labels, mode, params, config)
+      loss = graph_output['loss'] if 'loss' in graph_output else None
+
+    self._maybe_restore_checkpoint()
+
+    with tf.variable_scope(self._name + '/optim'):
+      train_op = None
+      if mode == tf.estimator.ModeKeys.TRAIN:
+        if 'train_op' in graph_output:
+          train_op = graph_output['train_op']
+          graph_output['train_op'] = None  # remove from preds to prevent error
+        elif loss is not None:
+          train_op = self._optimize_loss_fn(params, loss)
+
+        if params.get('train_log_metrics') and self._metric_fn:
+          metric_ops = self._metric_fn(graph_output=graph_output, labels=labels, weights=weights)
+          for metric_name in metric_ops:
+            tf.summary.scalar(
+                name="training_metric_" + metric_name,
+                tensor=metric_ops[metric_name][1])  # index 0 contains value_op, 1 contains update_op
+
+    if mode == tf.estimator.ModeKeys.PREDICT and self._export_output_fn is not None:
+      # note that this is ignored by the predict method.
+      # Estimator only uses export_output_fn for export_model.
+      export_outputs = self._export_output_fn(graph_output)
+    else:
+      export_outputs = None
+
+    if mode == tf.estimator.ModeKeys.EVAL and self._metric_fn:
+      eval_metric_ops = self._metric_fn(graph_output=graph_output, labels=labels, weights=weights)
+    else:
+      eval_metric_ops = None
+
+    # None and loss (scalar, not sliceable by TFMA) should be removed from the graph_output
+    preds = {key: graph_output[key] for key in graph_output
+             if (graph_output[key] is not None) and (key != 'loss')}
+
+    init_feed_dict = twml.contrib.initializers.get_init_feed_dict()
+    scaffold = tf.train.Scaffold(init_feed_dict=init_feed_dict)
+
+    # Clear the init feed collection to avoid serializing the initializers.
+    twml.contrib.initializers.clear_init_feed_collection()
+
+    # save the estimator spec for use by later methods and hooks (warning: often reset)
+    self._current_estimator_spec = tf.estimator.EstimatorSpec(
+        mode=mode,
+        predictions=preds,
+        export_outputs=export_outputs,
+        loss=loss,
+        train_op=train_op,
+        eval_metric_ops=eval_metric_ops,
+        scaffold=scaffold,
+    )
+
+    return self._current_estimator_spec
+
+  def get_train_hooks(self):
+    """Return SessionRunHooks used during training.
+
+    By default, training uses one hook, `tf.train.StepCounterHook`, for monitoring step speed.
+
+    If self._profiler_steps is set then we also use `tf.train.ProfilerHook`
+    for profiling.
+    """
+    # Instead of having every_n_steps be a constant number,
+    # change it dynamically based on batch size.
+    # Ideally we should be using every_n_secs, but that seems buggy as of 1.7.
+    # every_n_steps = 204800 / batch_size
+    every_n_steps = ((2048 * 100) // self._params.train_batch_size)
+    step_counter = tf.train.StepCounterHook(
+        every_n_steps=every_n_steps, output_dir=self._save_dir
+    )
+    train_hooks = [step_counter]
+
+    if self._profiler_steps is not None:
+      if not self._params.get('distributed') or self._estimator.config.is_chief:
+        profiler = tf.train.ProfilerHook(
+            save_steps=self._profiler_steps,
+            output_dir=self._save_dir
+        )
+        train_hooks.append(profiler)
+
+    return train_hooks
+
+  def is_task_type(self, name):
+    """
+    Helper function to specify if the current process is of the given worker type.
+    Note: This can only be called *after* self._hogwild_setup() is called in __init__().
+    """
+    if os.environ.get('TF_CONFIG'):
+      return self._estimator.config.task_type == name
+    return True
+
+  def is_evaluator(self):
+    """
+    Helper function to let you know if the worker is the evaluator.
+    Note: This can only be called *after* self._hogwild_setup() is called in __init__().
+    """
+    return self.is_task_type("evaluator")
+
+  def is_chief(self):
+    """
+    Helper function to let you know if the worker is the chief.
+    Note: This can only be called *after* self._hogwild_setup() is called in __init__().
+    """
+    return self.is_task_type("chief") or self.is_task_type("master")
+
+  def is_ps(self):
+    """
+    Helper function to let you know if the task is a parameter server.
+    """
+    if os.environ.get('TF_CONFIG') and self._estimator.config.task_type == 'ps':
+      return True
+    return False
+
+  def _exit_ps_after_training_complete(self):
+    """
+    Helper function to shut down the parameter server after the training job
+    completes (either succeeds or fails).
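+    A watcher thread polls the job status (via k8s on GKE, via Aurora
+    otherwise) and exits the process once the job finishes.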
+ """ + if not self.is_ps(): + return + + # No need to exit ps if on the same machine + if os.environ.get('TWML_HOGWILD_PORTS'): + return + + if self._params.get('disable_auto_ps_shutdown', False): + logging.info("Skip shutting down parameter server after training complete [--disable_auto_ps_shutdown is set]") + return + + # checking job status is different on gke vs aurora + if self._is_on_gke(): + get_job_status = functools.partial( + k8s_status.get_training_job_status, + cluster=None, + namespace=os.environ['TWML_JOB_ROLE'], + environment=os.environ['TWML_JOB_ENV'], + job_name=os.environ['TWML_JOB_NAME'], + using_tsd=True) + else: + get_job_status = functools.partial( + get_distributed_training_job_path, + base_job_path=get_distributed_training_job_path() + ) + + def wait_complete_then_exit(): + retry_max = 60 + retry = 0 + while True: + try: + training_status = get_job_status() + if training_status == TrainingJobStatus.FINISHED: + logging.info("Distributed training job succeed, shutting down parameter server.") + os._exit(0) + elif training_status == TrainingJobStatus.FAILED: + logging.info("Distributed training job failed, shutting down parameter server.") + os._exit(0) + elif training_status == TrainingJobStatus.NOT_FOUND: + raise Exception("Distributed training job status not found.") + else: + poke_interval = random.randrange(60, 90) # prevent spike QPS to aurora endpoint + time.sleep(poke_interval) + retry = 0 + except Exception as e: + if retry >= retry_max: + raise e # only exception in this thread, won't fail parameter server thread + retry += 1 + poke_interval = random.randrange(60, 90) + retry * 10 + logging.warn("Error getting distributed training job status, will retry after %s seconds." % poke_interval) + time.sleep(poke_interval) + Thread(target=wait_complete_then_exit).start() + + def get_eval_hooks(self): # pylint: disable=no-self-use + """ Return SessionRunHooks used during evaluation.""" + return None + + def get_predict_hooks(self): + """ Return hooks used during prediction. + If profiler_steps is set in the constructor to the Trainer, + we pass a tf.Train.ProfilerHook to the estimator's predict function. + """ + hooks = [] + if self._profiler_steps is not None: + profiler = tf.train.ProfilerHook( + save_steps=self._profiler_steps, + output_dir=self._save_dir + ) + hooks.append(profiler) + return hooks + + def learn(self, train_input_fn=None, eval_input_fn=None, + train_max_steps=None, + train_steps=None, eval_steps=None, + train_hooks=None, eval_hooks=None, + early_stop_metric=None, early_stop_patience=-1, + early_stop_minimize=True, early_stop_tolerance=0, start_epoch=0, + exporters=None, export_output_fn=None, max_duration=None): + """ + Train and evaluate the estimator for ``train_max_steps`` steps. + Each epoch involves ``train_steps`` training steps followed + by ``eval_steps`` evaluation steps. Note that each step + is a ``session.run()``, that is, each batch is a step. + + Args: + train_max_steps: + maximum number of global steps of training to run. + Defaults to params.train_max_steps. + None-values cause learn() to terminate after *one* call to train() and evaluate(), + which is usually useful when using train_steps=-1 + Non-positive values trains indefinitely in a loop (use with caution), + which is usually useful when used with early stopping. + train_steps: + number of training steps per epoch. For example, 100 means each + training epoch will end after processing 100 batches. + Defaults to params.train_steps. 
+        Non-positive values and None-values go through the entire training set each epoch.
+      eval_steps:
+        number of evaluation steps per epoch.
+        Defaults to params.eval_steps.
+        Non-positive values and None-values go through the entire evaluation set each epoch.
+      train_input_fn:
+        Function to iterate through the training set. It is passed to estimator.train.
+      eval_input_fn:
+        Function to iterate through the evaluation set. It is passed to estimator.evaluate.
+      train_hooks:
+        List of SessionRunHooks used for training. Defaults to self.get_train_hooks().
+      eval_hooks:
+        List of SessionRunHooks used for evaluation. Defaults to self.get_eval_hooks().
+      start_epoch:
+        The epoch from which to start learn. If you want to do training and evaluation
+        for N epochs, you can call ``learn()`` in a loop as follows:
+
+        .. code-block:: python
+
+          for epoch in range(1, max_epoch):
+            trainer.learn(start_epoch=epoch)
+
+      exporters:
+        List of exporters called at the end of each evaluation run.
+        Defaults to None.
+      export_output_fn:
+        The output format to use for exported models.
+        Only used if exporters is not None.
+
+    Early-stopping arguments:
+      early_stop_metric:
+        String specifying the metric to early-stop on. Required with positive
+        ``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc.
+        The string is used to extract the relevant tensor Op from the dict returned by
+        the get_eval_metric_ops method. For ``metrics`` passed to the constructor,
+        the string is one of those. For multi-class (that is, multi-metric)
+        metrics, the string may be appended with a ``_0``, ``_1``, etc. or one
+        of the ``multi_metric_names`` (one per class).
+      early_stop_patience:
+        Maximum number of epochs to wait for an improvement in the early_stop_metric
+        before breaking off training. For example, a patience of 10 means that
+        training will have 10 epochs to improve the metric before it is killed.
+        Whenever the metric is improved before running out of patience,
+        patience is reset to ``early_stop_patience``.
+        Defaults to -1 (that is, no early-stopping).
+      early_stop_minimize:
+        Set this to True (the default) for metrics that need to be minimized
+        (like ``loss``). Metrics like ``accuracy`` that need to be maximized
+        should set this to False.
+      early_stop_tolerance:
+        A non-negative tolerance for comparing early_stop_metric.
+        E.g. when maximizing, the condition is current_metric > best_metric + tolerance.
+        Defaults to 0.
+      max_duration:
+        A float. When this argument is defined, the job will automatically terminate after
+        `max_duration` seconds if it has not already completed.
+
+    Returns:
+      The directory where the checkpoints were saved.
+      That is, save_dir.
+      You can point TensorBoard to this directory to get metrics,
+      or pass it to another Trainer via ``init_from_dir`` when doing
+      multi-phase training.
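+
+    For instance, a single call that trains with early stopping on the
+    evaluation loss might look like this (a sketch; the input functions are
+    user-provided):
+
+    .. code-block:: python
+
+      save_dir = trainer.learn(
+          train_input_fn=train_input_fn,
+          eval_input_fn=eval_input_fn,
+          early_stop_metric='loss',
+          early_stop_patience=5,
+          early_stop_minimize=True)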
+ """ + # pylint: disable=too-many-branches + + if not callable(train_input_fn): + raise ValueError("Expecting callable train_input_fn function") + if not callable(eval_input_fn): + raise ValueError("Expecting callable eval_input_fn function") + + if os.environ.get('TF_CONFIG'): + raise ValueError("trainer.learn() can not be used with distributed / hogwild setups") + + if exporters and export_output_fn: + self._export_output_fn = export_output_fn + + train_hooks = self.get_train_hooks() if train_hooks is None else train_hooks + eval_hooks = self.get_eval_hooks() if eval_hooks is None else eval_hooks + eval_hooks = [] if eval_hooks is None else eval_hooks + + if train_max_steps is None: + train_max_steps = self.params.get('train_max_steps') + + if train_steps is None: + train_steps = self.params.train_steps + if train_steps <= 0: + train_steps = None + + if eval_steps is None: + eval_steps = self.params.eval_steps + if eval_steps <= 0: + eval_steps = None + + if early_stop_patience > 0: + assert train_max_steps is not None, "Early stopping and max_steps=None are not compatible." + # prepare early stopping hook (which also handles logic here) + self._is_early_stopping = True + early_stop_hook = twml.hooks.EarlyStopHook( + metric=early_stop_metric, + checkpoint_dir=self._save_dir, + patience=early_stop_patience, + minimize=early_stop_minimize, + tolerance=early_stop_tolerance, + get_estimator_spec_fn=lambda: self.current_estimator_spec, + start_epoch=start_epoch) + # add early stop hook to eval hooks + eval_hooks.append(early_stop_hook) + + if max_duration is not None: + train_early_stop_duration_hook = twml.hooks.EarlyStopDuration( + max_duration=max_duration, + exit_on_end=False, + save_dir=self._save_dir, + overwrite=True, + ) + train_hooks.append(train_early_stop_duration_hook) + + eval_early_stop_duration_hook = twml.hooks.EarlyStopDuration( + max_duration=max_duration, + exit_on_end=False, + save_dir=self._save_dir, + overwrite=True, + ) + eval_hooks.append(eval_early_stop_duration_hook) + + if not self._is_early_stopping: + if (train_max_steps is not None) and (train_max_steps <= 0): + if ((max_duration is not None) and (max_duration < 0)) or (max_duration is None): + logging.warn("train.max_steps is non-positive, and no early or duration stopping is configured. " + "Training job will loop forever.") + + if train_max_steps is not None and train_max_steps > 0: + # we can't pass max_steps AND steps to estimator.train. + # so we pass steps to estimator.train and max_steps to this hook instead... + stop_at_step_hook = twml.hooks.StopAtStepHook(last_step=train_max_steps) + train_hooks.append(stop_at_step_hook) + + with self.experiment_tracker.track_experiment(eval_hooks, + lambda: self.current_estimator_spec): + # alternate training and evaluation epochs + epoch = start_epoch + while True: + logging.info("Training epoch %d", epoch) + self._estimator.train(train_input_fn, steps=train_steps, hooks=train_hooks) + + logging.info("Evaluating epoch %d", epoch) + eval_result = self._estimator.evaluate( + eval_input_fn, steps=eval_steps, hooks=eval_hooks) + + if exporters: + checkpoint_path = self.estimator.latest_checkpoint() + for exporter in exporters: + export_path = os.path.join(self._save_dir, "export", exporter.name) + exporter.export( + estimator=self.estimator, export_path=export_path, + checkpoint_path=checkpoint_path, eval_result=eval_result, + is_the_final_export=False) + + # If train_max_step is none. Terminate after one loop. 
+        if train_max_steps is None:
+          break
+
+        # If stop_at_step_hook requested a stop, break
+        if train_max_steps > 0 and stop_at_step_hook.stop_requested:
+          break
+
+        # early-stopping logic is handled internally by the hook
+        if early_stop_patience > 0 and early_stop_hook.should_stop:
+          # but we still need to break here
+          break
+        epoch += 1
+
+    self.write_state_to_disk(save_dir=self._save_dir, filename='_SUCCESS')
+
+    return self._save_dir
+
+  def get_train_spec(self, input_fn, max_steps=None, hooks=None):
+    """Get the TrainSpec used by ``tf.train.train_and_evaluate``."""
+    if not callable(input_fn):
+      raise ValueError("Expecting callable train_input_fn")
+
+    if max_steps is None:
+      max_steps = self.params.train_max_steps
+
+    if max_steps is not None and max_steps <= 0:
+      max_steps = None
+
+    hooks = self.get_train_hooks() if hooks is None else hooks
+
+    return tf.estimator.TrainSpec(input_fn=input_fn,
+                                  max_steps=max_steps,
+                                  hooks=hooks)
+
+  def get_eval_spec(self, input_fn, steps=None, delay=None, period=None,
+                    hooks=None, exporters=None):
+    """Get the EvalSpec used by ``tf.train.train_and_evaluate``."""
+    if not callable(input_fn):
+      raise ValueError("Expecting callable eval_input_fn")
+
+    if steps is None:
+      steps = self.params.eval_steps
+
+    if steps <= 0:
+      steps = None
+
+    if delay is None:
+      delay = self.params.eval_delay
+
+    if period is None:
+      period = self.params.eval_period
+
+    hooks = self.get_eval_hooks() if hooks is None else hooks
+
+    eval_name = self.params.get("eval_name", None)
+
+    return tf.estimator.EvalSpec(input_fn=input_fn,
+                                 steps=steps,
+                                 name=eval_name,
+                                 start_delay_secs=delay,
+                                 throttle_secs=period,
+                                 hooks=hooks,
+                                 exporters=exporters)
+
+  def train_and_evaluate(self, train_input_fn=None, eval_input_fn=None,
+                         train_max_steps=None, eval_steps=None,
+                         eval_delay=None, eval_period=None,
+                         train_hooks=None, eval_hooks=None,
+                         early_stop_metric=None, early_stop_patience=-1,
+                         early_stop_minimize=True, early_stop_tolerance=0, exporters=None,
+                         export_output_fn=None, max_duration=None):
+    """
+    Train and evaluate the estimator for ``train_max_steps`` steps
+    using ``tf.estimator.train_and_evaluate``.
+    With a cluster configuration provided in the ``TF_CONFIG`` environment variable, this method
+    can be used for distributed training (multi-node or multi-process).
+    Unlike the ``learn`` method, training is continuous with ``train_max_steps``.
+    For the distributed use case, evaluation happens periodically.
+    That is, after ``eval_delay`` seconds, an evaluation epoch of ``eval_steps`` steps
+    occurs every ``eval_period`` seconds. Evaluation happens on the most recent checkpoint.
+    TF defaults to saving checkpoints every 10 mins.
+    For the local use case, training occurs for train_max_steps epochs followed by a
+    single evaluation. For the local use case we therefore recommend using learn() instead,
+    as it provides early-stopping and multiple evaluations.
+
+    ``train_and_evaluate`` will evaluate for ``eval_steps`` every ``eval_period`` seconds.
+    It will stop after ``train_max_steps`` is reached.
+
+    You must ensure that all workers/servers are assigned the same `save_dir`.
+
+    .. Note::
+
+      If the TF_CONFIG environment variable is set, this function assumes it's running a distributed job.
+
+    Args:
+      train_input_fn:
+        Function to iterate through the training set. It is passed to estimator.train_and_evaluate.
+      eval_input_fn:
+        Function to iterate through the evaluation set. It is passed to estimator.train_and_evaluate.
+      train_max_steps:
+        maximum number of global steps of training to run.
+        Defaults to params.train_max_steps.
+        Non-positive values and None-values train indefinitely (use with caution).
+      eval_steps:
+        number of steps per evaluation.
+        Defaults to params.eval_steps.
+        Non-positive values and None-values go through
+        the entire evaluation set for each evaluation.
+        Note that the number of eval_steps should be high enough to minimize noise.
+        This is especially true for early-stopping.
+      eval_delay:
+        Start the first evaluation after eval_delay. Defaults to params.eval_delay or 2*60s.
+      eval_period:
+        Run an evaluation every eval_period seconds. Defaults to params.eval_period or 10*60s.
+      exporters:
+        List of exporters called at the end of each evaluation run.
+        Defaults to None.
+      export_output_fn:
+        The output format to use for exported models.
+        Only used if exporters is not None.
+
+    Early-stopping arguments:
+      early_stop_metric:
+        String specifying the metric to early-stop on. Required with positive
+        ``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc.
+        The string is used to extract the relevant tensor Op from the dict returned by
+        the get_eval_metric_ops method. For ``metrics`` passed to the constructor,
+        the string is one of those. For multi-class (that is, multi-metric)
+        metrics, the string may be appended with a ``_0``, ``_1``, etc. or one
+        of the ``multi_metric_names`` (one per class).
+      early_stop_patience:
+        Maximum number of epochs to wait for an improvement in the early_stop_metric
+        before breaking off training. For example, a patience of 10 means that
+        training will have 10 epochs to improve the metric before it is killed.
+        Whenever the metric is improved before running out of patience,
+        patience is reset to ``early_stop_patience``.
+        Defaults to -1 (that is, no early-stopping).
+      early_stop_minimize:
+        Set this to True (the default) for metrics that need to be minimized
+        (like ``loss``). Metrics like ``accuracy`` that need to be maximized
+        should set this to False.
+      early_stop_tolerance:
+        A non-negative tolerance for comparing early_stop_metric.
+        E.g. when maximizing, the condition is current_metric > best_metric + tolerance.
+        Defaults to 0.
+      max_duration:
+        A float. When this argument is defined, the job will automatically terminate after
+        `max_duration` seconds if it has not already completed.
+
+    Returns:
+      The directory where the checkpoints were saved.
+    """
+
+    logging.info("WARNING: Trainer.train_and_evaluate is an EXPERIMENTAL API.")
+    logging.info("Trainer.train_and_evaluate may change or be removed in future versions.")
+
+    if not callable(train_input_fn):
+      raise ValueError("Expecting callable train_input_fn function")
+    if not callable(eval_input_fn):
+      raise ValueError("Expecting callable eval_input_fn function")
+
+    self._exit_ps_after_training_complete()
+
+    # Maybe export in eval processes.
+    if self.is_evaluator():
+      if self.params.get("eval_name") is not None:
+        # Do not export if running a special eval.
+        exporters = None
+        export_output_fn = None
+      elif exporters and export_output_fn:
+        self._export_output_fn = export_output_fn
+      else:
+        # Default option.
+        self._export_output_fn = None
+
+    train_hooks = self.get_train_hooks() if train_hooks is None else train_hooks
+    train_hooks = [] if train_hooks is None else train_hooks
+
+    eval_hooks = self.get_eval_hooks() if eval_hooks is None else eval_hooks
+    eval_hooks = [] if eval_hooks is None else eval_hooks
+
+    if train_max_steps is None:
+      train_max_steps = self.params.get('train_max_steps')
+
+    if eval_steps is None:
+      eval_steps = self.params.eval_steps
+    if eval_steps <= 0:
+      eval_steps = None
+
+    if eval_delay is None:
+      eval_delay = self.params.eval_delay
+    if eval_period is None:
+      eval_period = self.params.eval_period
+
+    if early_stop_patience > 0:
+      # when training hooks detect this file, they request a stop to training
+      early_stop_path = os.path.join(self._save_dir, 'earlystop_now.txt')
+      # prepare the early stopping hook (which also handles the early-stopping logic)
+
+      self._is_early_stopping = True
+
+      eval_early_stop_hook = twml.hooks.EarlyStopHook(
+          metric=early_stop_metric,
+          checkpoint_dir=self._save_dir,
+          patience=early_stop_patience,
+          minimize=early_stop_minimize,
+          tolerance=early_stop_tolerance,
+          get_estimator_spec_fn=lambda: self.current_estimator_spec,
+          file_path=early_stop_path,
+          exit_on_end=os.environ.get('TF_CONFIG') is not None)  # only exit for distributed jobs
+      # add early stop hook to eval hooks
+      eval_hooks.append(eval_early_stop_hook)
+
+      # prepare the commensurate training hook
+      train_early_stop_hook = twml.hooks.StopIfExistsHook(early_stop_path)
+      train_hooks.append(train_early_stop_hook)
+
+    if max_duration is not None:
+      train_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
+          max_duration=max_duration,
+          exit_on_end=False,
+          save_dir=self._save_dir,
+          overwrite=self.is_chief()
+      )
+      eval_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
+          max_duration=max_duration,
+          exit_on_end=os.environ.get('TF_CONFIG') is not None,  # only exit for distributed jobs
+          save_dir=self._save_dir,
+          overwrite=False
+      )
+
+      train_hooks.append(train_early_stop_duration_hook)
+      eval_hooks.append(eval_early_stop_duration_hook)
+
+    with self.experiment_tracker.track_experiment(eval_hooks, lambda: self.current_estimator_spec):
+      train_spec = self.get_train_spec(train_input_fn, train_max_steps, train_hooks)
+      eval_spec = self.get_eval_spec(eval_input_fn, eval_steps,
+                                     eval_delay, eval_period,
+                                     eval_hooks, exporters)
+      self._train_and_evaluate(train_spec, eval_spec)
+
+    if self.is_chief():
+      self.write_state_to_disk(save_dir=self._save_dir, filename='_SUCCESS')
+
+    return self._save_dir
+
+  def _train_and_evaluate(self, train_spec, eval_spec):
+    """
+    Private method that calls
+    ``tf.estimator.train_and_evaluate(self._estimator, train_spec, eval_spec)``.
+    """
+    try:
+      tf.estimator.train_and_evaluate(self._estimator, train_spec, eval_spec)
+    except twml.errors.EarlyStopError:
+      # Ignore the exception if on the evaluator.
+      if not self.is_evaluator():
+        raise
+
+  def train(self, input_fn=None, steps=None, hooks=None):
+    """
+    Train the estimator for `steps` training steps.
+
+    Args:
+      steps:
+        number of steps for which to perform training. For example, 100 means the
+        training run will end after processing 100 batches.
+        Defaults to None, i.e. trains on the entire dataset a single time.
+        Non-positive values and None-values go through the entire training set each epoch.
+      input_fn:
+        Function to iterate through the training set. It is passed to estimator.train.
+      hooks:
+        List of SessionRunHooks used for training. Defaults to self.get_train_hooks().
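+
+    For example (a sketch; ``train_input_fn`` is user-provided):
+
+    .. code-block:: python
+
+      trainer.train(input_fn=train_input_fn, steps=1000)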
+ """ + if os.environ.get('TF_CONFIG') and "is_calibrating" not in self.params: + raise ValueError("trainer.train() can not be used with distributed / hogwild setups") + + if not callable(input_fn): + raise ValueError("Expecting callable input_fn function") + + if self._is_early_stopping: + raise ValueError("Can not call train() after learn() when using early stopping.") + + hooks = self.get_train_hooks() if hooks is None else hooks + self._estimator.train(input_fn, steps=steps, hooks=hooks) + return self + + def evaluate(self, input_fn=None, steps=None, hooks=None, name=None): + """ + Evaluate the estimator for `steps` evaluation steps. + + Args: + steps: + number of steps for which to perform evaluation. For example, 100 means each + evaluation will end after processing 100 batches. + Defaults to None. i.e. evaluates on the entire dataset a single time. + Negative values and None-values go through the entire training set each epoch. + input_fn: + Function to iterate through evaluation set. It is passed to estimator.evaluate. + hooks: + List of SessionRunHooks used for evaluation. Defaults to None. + Note that, unlike learn(), hooks defaults to None instead of self.get_eval_hooks() + as the latter may implement early-stopping, which isn't necessarilty the desired + behavior when calling evaluate() on its own. + name: + Name of the evaluation if user needs to run multiple evaluations on different data sets. + Metrics for different evaluations are saved in separate folders, + and appear separately in tensorboard. + + Returns: + If `is_evaluator()`, returns a dict containing the evaluation metrics specified + in `metric_fn` keyed by name, as well as an entry `global_step` that contains + the value of the global step for which this evaluation was performed. + Otherwise (i.e. `is_evaluator() == False`), returns None. + """ + if not self.is_evaluator(): + return None + + if not callable(input_fn): + raise ValueError("Expecting callable input_fn function") + + hooks = self.get_eval_hooks() if hooks is None else hooks + hooks = [] if hooks is None else hooks + + # for consistency with train/learn + eval_steps = None if steps is not None and steps < 0 else steps + + with self.experiment_tracker.track_experiment(hooks, lambda: self.current_estimator_spec, name=name): + checkpoint = self.best_or_latest_checkpoint + computed_metrics = self._estimator.evaluate( + input_fn, + steps=eval_steps, + hooks=hooks, + checkpoint_path=checkpoint, + name=name + ) + + return computed_metrics + + def start_tensorboard(self, port=None): + """ + Start tensorboard process to visualize logs in save_dir. + """ + logging.info("Starting tensorboard.") + if self._tensorboard_handle: + logging.warn("Tensorboard already running. 
+      return
+
+    if port is None:
+      if 'tensorboard_port' not in self.params.values():
+        raise ValueError('You must specify a port for tensorboard to run on.')
+      elif self.params.tensorboard_port is None:
+        return
+      else:
+        port = self.params.tensorboard_port
+
+    mldash_path = 'experiments'
+    if self.experiment_tracker.path:
+      mldash_path += '/%s' % encode_url(self.experiment_tracker.experiment_id)
+    tensorboard_args = ['--logdir=%s' % self._save_dir, '--port=%d' % port]
+
+    try:
+      args = ['email_and_launch_tensorboard', mldash_path, '--'] + tensorboard_args
+      self._tensorboard_handle = subprocess.Popen(args)
+    except OSError:
+      try:
+        self._tensorboard_handle = subprocess.Popen(['tensorboard'] + tensorboard_args)
+      except OSError:
+        try:
+          # this will work with the Twitter internal pants build when run locally
+          args = ['./pants', 'run', 'twml:tensorboard', '--'] + tensorboard_args
+          self._tensorboard_handle = subprocess.Popen(args)
+        except OSError:
+          logging.error("No tensorboard installed, won't be able to visualize training in tensorboard.")
+
+  def stop_tensorboard(self):
+    """
+    Shut down this Trainer's associated Tensorboard.
+    """
+    if self._tensorboard_handle:
+      logging.info("Shutting down tensorboard.")
+      self._tensorboard_handle.kill()
+    else:
+      logging.warn("No known tensorboard process. Nothing done.")
+
+  def calibrate(self,
+                calibrator,
+                steps=None,
+                input_fn=None,
+                save_calibrator=True,
+                hooks=None):
+    """
+    Calibrate the calibrator for `steps` calibration steps using the estimator.train method.
+    The build_graph passed to the Trainer constructor should
+    call calibrator.accumulate using something like tf.py_func.
+    That way, when this method calls estimator.train, the calibrator will
+    accumulate one epoch of samples, after which this method calls calibrator.calibrate().
+    It is up to the user to then call calibrator.save() to save the calibrated Layer
+    and other information to disk for multi-phase training.
+
+    Args:
+      calibrator:
+        a twml.Calibrator instance or a dict of the form {name(str): twml.Calibrator}.
+      steps:
+        Maximum number of steps to accumulate examples for calibration. Optional.
+        If not specified, examples will be accumulated until all downsampled parts are processed.
+      input_fn:
+        Function to iterate through the training set. It is passed to estimator.train.
+      hooks:
+        List of SessionRunHooks used for training. Defaults to self.get_train_hooks().
+      save_calibrator:
+        Boolean (default: True). If set to True it will save the calibrator layer.
+    """
+
+    if not callable(input_fn):
+      raise ValueError("Expecting callable input_fn function")
+
+    # making everything a dict to avoid multiple ifs
+    if isinstance(calibrator, twml.contrib.calibrators.Calibrator):
+      calibrator = {"default": calibrator}
+
+    # This is a dummy call to train, since we cannot predict without training
+    # from the Estimator API
+    self._estimator.train(input_fn, steps=1)
+    max_steps = steps if steps is not None else -1
+    for name, clbrt in sorted(calibrator.items(), key=itemgetter(0)):
+      count = 0
+      for out in self._estimator.predict(input_fn, hooks=hooks, yield_single_examples=False):
+        if max_steps > 0 and count >= max_steps:
+          break
+        clbrt.accumulate_feature(out)
+        count += 1
+      clbrt.calibrate()
+
+    # this step is done to allow us to keep the current phase's event files for
+    # visualization on Tensorboard. It removes all files that
+    # are not event files.
+    # This piece of code should be removed when
+    # we deprecate the MDL calibrator (CX-12329).
+    for fname in tf.io.gfile.listdir(self._save_dir):
+      if not fname.startswith("events"):
+        tf.io.gfile.remove(os.path.join(self._save_dir, fname))
+
+    if save_calibrator:
+      # If we only have one calibrator, the calibrator signature
+      # will be set to default
+      if len(calibrator) == 1:
+        calibrator = calibrator['default']
+        calibrator.save(
+            self.params.save_dir,
+            name=calibrator.name,
+            verbose=True
+        )
+      else:
+        for name, clbrt in calibrator.items():
+          clbrt.save(
+              self.params.save_dir,
+              name=clbrt.name + str(name),
+              verbose=True
+          )
+
+  def predict(self, *args, **kwargs):
+    """
+    Wrapper over the tensorflow `Estimator.predict
+    `_
+    method. See that documentation for a description of the accepted arguments.
+
+    If hooks is passed as an argument, the specified hooks are used.
+    Otherwise, if profiler_steps was specified in the constructor of the Trainer, a
+    tf.train.ProfilerHook is passed to the predict interface;
+    failing that, hooks is set to an empty list.
+    """
+    if 'hooks' not in kwargs and len(args) < 3:
+      # If hooks is not specified as a keyword argument nor as a positional argument,
+      # add hooks as a keyword argument.
+      kwargs['hooks'] = self.get_predict_hooks()
+
+    return self.estimator.predict(*args, **kwargs)
+
+  def hub_export(self,
+                 name,
+                 serving_input_receiver_fn,
+                 export_dir=None,
+                 checkpoint_path=None,
+                 export_task_type_overrider=None):
+    """
+    Exports registered modules into a save directory.
+
+    This method creates a directory under export_path with the saved TF Hub modules:
+    one sub-directory (named export_name) per module registered via register_module_for_export.
+
+    Arguments:
+      name:
+        unique name of the module to export.
+      serving_input_receiver_fn:
+        A function with no arguments that returns a ServingInputReceiver.
+        This is used with the estimator passed to export() to build the graph (in PREDICT mode)
+        that registers the modules for export. The model in that graph is never run,
+        so the actual data provided by this input fn does not matter.
+      export_dir:
+        A string containing a directory where to write the export directories.
+        Defaults to the save_dir.
+      checkpoint_path:
+        The checkpoint path to export. Defaults to the latest.
+      export_task_type_overrider:
+        Specifies the task type that will override the default task type used for export
+        (hogwild training defaults to evaluator, otherwise defaults to chief).
+    """
+    if export_task_type_overrider:
+      if not self.is_task_type(export_task_type_overrider):
+        logging.info(
+            f"Trainer.hub_export ignored due to the process not being {export_task_type_overrider}")
+        return
+    else:
+      if self._using_hogwild:
+        if not self.is_evaluator():
+          logging.info("Trainer.hub_export ignored due to the process not being evaluator.")
+          return
+      else:
+        if not self.is_chief():
+          logging.info("Trainer.hub_export ignored due to the process not being chief.")
+          return
+
+    if export_dir:
+      export_dir = sanitize_hdfs_path(export_dir)
+
+    if checkpoint_path:
+      checkpoint_path = sanitize_hdfs_path(checkpoint_path)
+    else:
+      checkpoint_path = self.best_or_latest_checkpoint
+
+    export_dir = export_dir if export_dir is not None else self._save_dir
+    exporter = hub.LatestModuleExporter(name, serving_input_receiver_fn)
+    # The path_exporter by default contains a timestamp directory in its path.
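+    # (i.e. exporter.export() writes to <export_dir>/<timestamp>/; the copy
+    # loop below moves the modules up one level so that callers don't need to
+    # know the timestamp when loading them.)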
+    path_exporter = exporter.export(estimator=self.estimator,
+                                    export_path=export_dir,
+                                    checkpoint_path=checkpoint_path)
+
+    # LatestModuleExporter.export() returns a binary string on Cloud ML Engine
+    # but tf.io.gfile.listdir() does not; this is an issue when joining paths
+    if isinstance(path_exporter, bytes):
+      path_exporter = path_exporter.decode()
+
+    # Copying the saved hub modules to export_dir so we don't need to specify
+    # the timestamp when loading the modules.
+    # This is a workaround due to the current implementation of hub.LatestModuleExporter.
+    # This works for multiple hub modules.
+    hub_exported_modules = tf.io.gfile.listdir(path_exporter)
+
+    backup_dir = os.path.join(export_dir, "backups",
+                              datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
+
+    for folder in hub_exported_modules:
+      hub_module_oldpath = os.path.join(path_exporter, folder)
+      hub_module_newpath = os.path.join(export_dir, folder)
+
+      # If the destination already exists, move it to the backup
+      if tf.io.gfile.exists(hub_module_newpath):
+        # Ensure backup_dir exists
+        tf.io.gfile.makedirs(backup_dir)
+        hub_module_backup = os.path.join(backup_dir, folder)
+        tf.io.gfile.rename(hub_module_newpath, hub_module_backup)
+
+      tf.io.gfile.rename(hub_module_oldpath, hub_module_newpath)
+
+    # Since the timestamped folder still exists but is now empty, we can delete it.
+    tf.io.gfile.rmtree(path_exporter)
+
+  def _is_on_gke(self) -> bool:
+    """Returns True if running on gke."""
+    cluster = os.environ.get('TWML_JOB_CLUSTER')
+    if not cluster or cluster in {'smf1', 'atla'}:
+      return False
+    return True
+
+  def _maybe_del_tsd_exit(self, state_files) -> None:
+    """Handle potential early exit and TwitterSetDeployment deletion.
+
+    If:
+      - training is distributed,
+      - we are running on GKE, and
+      - training is finished (all state_files exist),
+    we will exit early and not restart work.
+
+    If --distributed_training_cleanup = True then we will also handle
+    cleaning up the TwitterSetDeployments.
+
+    Args:
+      state_files: A python list indicating the state files used to determine
+        the finish state of the job.
+    """
+    # the job type that is responsible for experiment tracking will remain alive
+    # until it marks the experiment as finished.
+    if self.experiment_tracker._env_eligible_for_recording_experiment:
+      exp_status = self.experiment_tracker.get_run_status()
+      if exp_status and exp_status not in {'Success', 'Failed'}:
+        logging.info(
+            f"Not exiting early because experiment is still {exp_status}."
+        )
+        return
+
+    # do not bother if we are on prem
+    if not self._is_on_gke():
+      logging.info("No need to exit early because running on prem.")
+      return
+
+    states = [
+        twml.util.file_exist_in_dir(self._save_dir, state_file) for state_file in state_files]
+    do_not_restart = (self._params.get('distributed') and all(states))
+    if not do_not_restart:
+      return
+
+    logging.info(
+        f"Exiting early because a _SUCCESS file already exists in {self._save_dir}")
+    if self._params.get('distributed_training_cleanup'):
+      resource_name = '-'.join([
+          os.environ['TWML_JOB_NAME'],
+          os.environ['TWML_DISTRIBUTED_JOB_TYPE'],
+          os.environ['TWML_JOB_ENV'],
+      ])
+      logging.info(f"Deleting TwitterSetDeployment {resource_name}")
+      # each job type will manage its own deletion so that deletion happens
+      # in the trainer init call for every job type;
+      # otherwise we may kill another job type during an important
+      # process like experiment tracking management (handled by the evaluator)
+      kubectl_delete_by_name(
+          zone=None,
+          namespace=os.environ['TWML_JOB_ROLE'],
+          resource_type=Resource.TWITTERSETDEPLOYMENTS.value,
+          resource_name=resource_name,
+          wait=False,
+      )
+    sys.exit(0)
+
+  def write_state_to_disk(self, save_dir, filename='_SUCCESS') -> None:
+    """Write a state file to disk to indicate the state of the training process.
+    This is usually used to mark training progress and to determine the starting
+    point when a job restarts/resumes.
+
+    Args:
+      save_dir: A str of the local/gcs/hdfs dir to write the state file to.
+      filename: A str indicating the name of the state file. Defaults to `_SUCCESS`.
+    """
+    file_path = os.path.join(save_dir, filename)
+    if tf.io.gfile.exists(file_path):
+      tf.logging.warn(f'{file_path} already exists.')
+      return
+
+    with tf.io.gfile.GFile(file_path, 'w') as f:
+      f.write('')
\ No newline at end of file
diff --git a/twml/twml/util.py b/twml/twml/util.py
new file mode 100644
index 000000000..cd7679a6f
--- /dev/null
+++ b/twml/twml/util.py
@@ -0,0 +1,942 @@
+"""
+This module contains utility functions for twml.
+"""
+
+import argparse
+from datetime import datetime
+import itertools
+import json
+import logging as _logging
+import os
+import re
+
+from twitter.ml.common.resources import AuroraPath
+from twitter.deepbird.hparam import HParams
+from twitter.deepbird.io.util import (
+    _get_feature_id,  # noqa: F401
+    feature_id,  # noqa: F401
+    preprocess_feature_regex,  # noqa: F401
+    preprocess_path,  # noqa: F401
+    sanitize_hdfs_path,  # noqa: F401
+    is_string,  # noqa: F401
+    list_files,  # noqa: F401
+    match_files,  # noqa: F401
+)
+from twitter.deepbird.io.legacy.util import (
+    batch_apply,  # noqa: F401
+    boolean_mask,  # noqa: F401
+    fixed_length_tensor,  # noqa: F401
+)
+from twitter.deepbird.sparse.util import (
+    convert_to_sparse,  # noqa: F401
+    limit_bits,  # noqa: F401
+)
+
+from dateutil import rrule
+from joblib import delayed, Parallel
+from six import string_types
+
+from absl import logging
+from libtwml import CLIB, OPLIB  # noqa: F401
+import tensorflow.compat.v1 as tf
+from tensorflow.python.platform import tf_logging
+import twml
+from twml.feature_config import FeatureConfigBuilder
+
+
+# big_prime is less than 2**32.
+# It just needs to be co-prime with powers of 2;
+# any large prime is sufficient, though primality is not strictly necessary.
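+# For example, multiplicative_hash(id) below simply computes id * HASHING_PRIME;
+# callers are expected to mask the result down to n bits afterwards (see
+# rehash_sparse_features_nbits).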
+HASHING_PRIME = 2479700537
+
+
+def multiplicative_hash(input, hash_constant=HASHING_PRIME):
+  return input * hash_constant
+
+
+def _return_tensors_from_checkpoint_folder(init_dir, model_name=None):
+  """Returns the tensors list from a checkpoint folder.
+
+  Args:
+    init_dir: Name of the checkpoint directory.
+    model_name: the model which we will use to obtain the checkpoint
+      (e.g. model.ckpt-50000); if set to None it will default to the
+      latest model saved in the checkpoint file.
+  """
+  if model_name is None:
+    # gets the most recently generated model.ckpt file
+    model_path = tf.train.latest_checkpoint(init_dir)
+    if model_path is None:
+      raise ValueError("Could not find a valid model checkpoint inside the directory")
+  else:
+    model_path = os.path.join(init_dir, model_name)
+  reader = tf.train.NewCheckpointReader(model_path)
+  try:
+    return (reader.debug_string().decode("utf-8"))
+  except OSError:
+    logging.error('Could not decode the string')
+
+
+def get_scope_dict(init_dir, incoming_scope_name, current_scope_name, model_name=None):
+  """Returns a tensors map from a checkpoint file.
+
+  Args:
+    init_dir:
+      Name of the checkpoint directory.
+    incoming_scope_name:
+      scope name of the previous phase
+    current_scope_name:
+      scope name of the current phase
+    model_name:
+      the model which we will use to obtain the checkpoint
+      (e.g. model.ckpt-50000); if set to None it will default
+      to the latest model saved in the checkpoint file.
+  Returns:
+    init_map:
+      the init_map to pass when initializing from the checkpoint
+  """
+  init_map = {}
+  reader_dump = _return_tensors_from_checkpoint_folder(init_dir=init_dir,
+                                                       model_name=model_name).splitlines()
+  for member in reader_dump:
+    # remove global_step since it is not necessary
+    if 'global_step' not in member:
+      saved_variables = str(member.split(" ")[0])
+      saved_scope = saved_variables.rsplit('/', 1)[0] + "/"
+      new_scope = saved_scope.replace(incoming_scope_name, current_scope_name, 1)
+      # create key in init_map
+      if saved_scope not in init_map.keys():  # pylint: disable=dict-keys-not-iterating
+        init_map[saved_scope] = new_scope
+  return init_map
+
+
+def get_init_map(
+    init_from_dir,
+    exclude_var_names=None,
+    exclude_name_scopes=None,
+    name_scope_to_remove=None,
+    name_scope_to_prepend=None):
+  """
+  Builds a map for initializing from a checkpoint (see tf.train.init_from_checkpoint).
+
+  It assumes that the latter part of the variable names are consistent between the checkpoint and
+  the new model, but their name_scopes may be different. If the checkpoint model has variable names
+  of the form old/scope/var/foo, and the corresponding variable names for the new model should be
+  my/new/scope/var/foo, then you should set name_scope_to_remove = 'old/' and
+  name_scope_to_prepend = 'my/new/'.
+
+  This function can be used to
+
+  1. generate an ``init_map`` that can be passed to the ``Trainer`` init, or
+  2. generate an ``init_map`` directly inside ``build_graph_fn`` and pass it to
+     ``tf.train.init_from_checkpoint`` there, in which case you do not also need
+     to specify the ``init_map`` argument to the trainer.
+
+  Parameters
+  ----------
+  init_from_dir: Directory containing the checkpoint
+  exclude_var_names: list[str]
+    List of variables in the checkpoint that should be excluded from the map.
+  exclude_name_scopes: list[str]
+    List of name_scopes in the checkpoint model that should be excluded from the map.
+  name_scope_to_remove: str
+    portion of the name_scope for checkpoint variables that should not be included in the
+    variable names for the new model.
+  name_scope_to_prepend: str
+    name_scope to prepend to variable names in the checkpoint to give the variable names
+    for the new model.
+
+  Returns
+  -------
+  dict
+    keys are variable names in the checkpoint and values are variable names in the new model,
+    into which the checkpoint parameters should be loaded.
+  """
+  vars_to_restore = get_checkpoint_variable_names(
+      init_from_dir,
+      exclude_var_names=exclude_var_names,
+      exclude_scopes=exclude_name_scopes,
+  )
+
+  if name_scope_to_prepend is not None:
+    if not name_scope_to_prepend.endswith('/'):
+      name_scope_to_prepend += '/'
+
+  if name_scope_to_remove is not None:
+    if not name_scope_to_remove.endswith('/'):
+      name_scope_to_remove += '/'
+
+  init_map = {}
+
+  for var_name in vars_to_restore:
+    var_name_checkpoint = var_name
+
+    if name_scope_to_remove is not None:
+      var_name = var_name.replace(name_scope_to_remove, '')
+
+    var_name_new_model = var_name
+
+    if name_scope_to_prepend is not None:
+      var_name_new_model = name_scope_to_prepend + var_name_new_model
+
+    init_map[var_name_checkpoint] = var_name_new_model
+
+  return init_map
+
+
+def get_checkpoint_variable_names(model_dir, exclude_var_names=None, exclude_scopes=None):
+  """
+  Gets a list of variable names from the latest checkpoint in model_dir.
+  Removes variables with a scope defined by exclude_scopes, and/or with names defined by
+  exclude_var_names.
+
+  Args:
+    model_dir (str): Directory containing the checkpoint file for the pre-trained model
+    exclude_var_names (list): Optional variable names to exclude (can include full/partial scope)
+    exclude_scopes (list): Optional scopes to exclude
+
+  Returns:
+    list: variable names
+  """
+  checkpoint_path = tf.train.latest_checkpoint(model_dir)
+  variables_and_shapes = tf.train.list_variables(checkpoint_path)
+
+  def _keep(name):
+    if exclude_scopes and any(name.startswith(exc_scope) for exc_scope in exclude_scopes):
+      return False
+    if exclude_var_names and any(name.endswith(exc_var) for exc_var in exclude_var_names):
+      return False
+    return True
+
+  names = [x[0] for x in variables_and_shapes if _keep(x[0])]
+
+  return names
+
+
+def to_snake_case(name):
+  """
+  Changes name to snake case
+  """
+  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
+  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
+  # If the class is private, the name starts with "_", which is not secure
+  # for creating scopes. We prefix the name with "private" in this case.
+  if insecure[0] != '_':
+    return insecure
+  return 'private' + insecure
+
+
+def copy_phase_inputs(init_dir, dest_dir):
+  """Automatically copies the .json.tf files from the init_dir to the dest_dir
+  so we can load multiple parameters at the same time.
+
+  Args:
+    init_dir:
+      Name of the checkpoint directory.
+    dest_dir:
+      Name of the output directory.
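+
+  For example (a sketch; it assumes the ``init_from_dir`` and ``save_dir``
+  hparams are set on ``params``):
+
+  .. code-block:: python
+
+    copy_phase_inputs(params.init_from_dir, params.save_dir)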
+ """ + if init_dir is not None: + # we are using tf.io.gfile so we can use it with both local and hdfs paths + for files in tf.io.gfile.listdir(init_dir): + if files.endswith(".json.tf"): + src_file = os.path.join(init_dir, files) + dest_file = os.path.join(dest_dir, files) + if not tf.io.gfile.exists(dest_dir): + # creates the folder + try: + tf.io.gfile.makedirs(dest_dir) + # to prevent racing condition + except OSError: + if not tf.io.gfile.isdir(dest_dir): + raise + # dest_file may be old if it exists and + # dest_file gets copied several times in distributed training + tf.io.gfile.copy(src_file, dest_file, overwrite=True) + + +def rehash_sparse_features_nbits(sp_a, nbits, hash_fn=multiplicative_hash): + """ + Rehash the feature ids of the sparse tensor, + and limit the output to n bits. + + This is useful for making the distribution of + feature_ids more uniform, which may improve performance + in some situations. + + This would typically be used on the output of + PercentileDiscretizer, since it assigns many + bins to low-valued output feature ids. + + Input feature IDs should take values less than 2**32, + and nbits should be less than 32 + + Args: + sp_a: + a tf.SparseTensor object + nbits: + integer number of bits to mask output feature_ids + hash_fn: + Function that takes integer values and returns hashes of these values. + The output does not need to be masked to the desired number of bits, + as this masking will be taken care of. Default value = multiplicative_hash. + + Returns: + a new tf.SparseTensor + """ + + feature_ids = sp_a.indices[:, 1] + feature_ids = hash_fn(feature_ids) + + sample_ids = sp_a.indices[:, 0] + values = sp_a.values + dense_shape = sp_a.dense_shape + + indices = tf.stack([sample_ids, feature_ids], axis=1) + + sp_a = tf.SparseTensor(indices, values, dense_shape) + + # note - we need 2**nbits >= batch size + # otherwise, sample_ids will be squashed by the mask. + return limit_sparse_tensor_size(sp_a, nbits) + + +def convert_to_hparams(opt): + """ + Converts argparse.Namespace object to twitter.deepbird.hparam.hparam.HParams. + Note that tensorflow.contrib.training.HParams is gone in TF 2.x, and we forward ported + tensorflow.contrib.training.HParams to twitter.deepbird.hparam.hapram.HParams. + + NOTE: If you are using estimators, please don't call this method and directly pass python dict + to TensorFlow estimator. Starting TensorFlow 2.0, Estimator will only accept dicts. + """ + + # Convert to dict so we can iterate through it cleanly. + if isinstance(opt, argparse.Namespace): + params_dict = vars(opt) + elif isinstance(opt, dict): + params_dict = opt + elif isinstance(opt, HParams): + logging.warning('If you are using Estimator, please pass python dict directly to Estimator.') + params_dict = opt.values() + else: + raise ValueError("Input can not be of type %s. " + "It can be one of { argparse.Namespace, dict, " + "twitter.deepbird.hparam.HParams}." + % type(opt)) + + params = HParams() + # Hack to convert all parameters from hdfs:/// format to hdfs://default/ + # Note: .items() makes a copy in python 2.7, but that is fine since the performance isn't critical. + for key, val in params_dict.items(): + val = params_dict[key] + # Fix the path if the value is a string + if isinstance(val, str): + params.add_hparam(key, sanitize_hdfs_path(val)) + else: + params.add_hparam(key, val) + + return params + + +def dynamic_partition(features, partitions, num_partitions=2, name=None): + """ + Partitions each of the tensor in features using the provided mask. 
+
+  Args:
+    features:
+      A single tensor or an iterable of tensors (list, tuple, dict)
+    partitions:
+      A bool or integer tensor representing the partitions.
+
+  Returns partitioned outputs as a list. Each element of the list is the same type as features.
+
+  This uses tf.dynamic_partition but adds the following niceties:
+    - features can be a list or dict of different tensor types.
+    - only a partition tensor is used to partition all the feature tensors recursively.
+    - the partition tensor is automatically converted into an integer tensor.
+    - defaults to num_partitions == 2
+  """
+
+  if not isinstance(features, (dict, list, tuple, tf.Tensor)):
+    raise AssertionError("features must be a dict, list, tuple, or tf.Tensor")
+
+  if isinstance(partitions, tf.Tensor):
+    partitions = tf.cast(partitions, tf.int32)
+
+  if isinstance(features, tf.Tensor):
+    return tf.dynamic_partition(features, partitions, num_partitions, name)
+
+  outputs = []
+  for _ in range(num_partitions):
+    if isinstance(features, (tuple, list)):
+      # Create an empty list of lists first; it will be converted to the right type afterwards.
+      outputs.append([None for _ in range(len(features))])
+    else:
+      outputs.append(dict())
+
+  iterable = features.items() if isinstance(features, dict) else enumerate(features)
+
+  # Handling partitions of nested containers is done here:
+  # recursively call dynamic_partition for containers.
+  for key, feature in iterable:
+    name_key = None if name is None else name + "_" + str(key)
+    if isinstance(partitions, tf.Tensor):
+      results = tf.dynamic_partition(feature, partitions, num_partitions, name_key)
+    else:
+      results = tf.dynamic_partition(feature, partitions[key], num_partitions[key], name_key)
+    # Append the result to the proper output container
+    for idx, result in enumerate(results):
+      outputs[idx][key] = result
+
+  # if the input is a tuple, convert the list of lists back to a list of tuples
+  if isinstance(features, tuple):
+    outputs = [type(features)(output) for output in outputs]
+
+  return outputs
+
+
+def write_file(filename, contents, encode=False):
+  '''
+  Optionally encodes contents and writes contents to a file.
+
+  Arguments:
+    filename:
+      path to the file where the contents will be saved.
+      Accepts HDFS and local paths.
+    contents:
+      contents to save to the file.
+      Must be a string when encode is False.
+    encode:
+      False | 'json'. When encode='json', contents is encoded
+      with json.dumps.
+  '''
+  if encode == 'json':
+    contents = json.dumps(contents)
+  elif not is_string(contents):
+    raise ValueError("Expecting string for encode=False")
+
+  graph = tf.Graph()
+  with graph.as_default():
+    write = tf.write_file(filename, contents)
+
+  with tf.Session(graph=graph) as sess:
+    sess.run(write)
+
+
+def read_file(filename, decode=False):
+  '''
+  Reads contents from a file and optionally decodes it.
+
+  Arguments:
+    filename:
+      path to the file where the contents will be loaded from.
+      Accepts HDFS and local paths.
+    decode:
+      False | 'json'. When decode='json', contents is decoded
+      with json.loads. When False, contents is returned as is.
+
+  Returns:
+    contents
+  '''
+  graph = tf.Graph()
+  with graph.as_default():
+    read = tf.read_file(filename)
+
+  with tf.Session(graph=graph) as sess:
+    contents = sess.run(read)
+    # a particular version of TF and/or Python may or may not decode the bytes from utf-8 to str
+    if not isinstance(contents, str):
+      contents = contents.decode()
+
+  if decode == 'json':
+    contents = json.loads(contents)
+
+  return contents
+
+
+def setup_tf_logging_formatter():
+  formatter = _logging.Formatter(
+    '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
+    None)
+  # Setting up absl logging verbosity
+  logging.set_verbosity('info')
+  logging.set_stderrthreshold('info')
+  logging.get_absl_handler().setFormatter(formatter)
+  tf.logging.set_verbosity(tf.logging.INFO)
+  # Set tensorflow logging handler format
+  if len(tf_logging.get_logger().handlers) > 0:
+    tf_logging.get_logger().handlers[0].setFormatter(formatter)
+
+
+def set_tensorflow_log_level(log_level):
+  """
+  Sets tensorflow's default logging level.
+
+  0. all logs are shown.
+  1. filter out INFO logs.
+  2. filter out WARNINGs and INFOs.
+  3. filter out ERRORs, WARNINGs, and INFOs.
+
+  Note that tf.Print outputs are INFO logs, so setting log_level above 0 would hide
+  output from tf.Print.
+  """
+  assert isinstance(log_level, int) and log_level >= 0 and log_level <= 3
+  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(log_level)
+
+
+def weighted_average(values, weights):
+  """
+  Compute a weighted average using the given values and weights.
+  For example, this is typically used to compute a weighted loss given sample weights.
+  """
+  return tf.reduce_sum(tf.multiply(values, weights)) / tf.reduce_sum(weights)
+
+
+def backup_checkpoint(checkpoint_path_prefix,
+                      backup_path='backup',
+                      empty_backup=True):
+  """
+  Creates a backup copy of a checkpoint in backup_path.
+  This function is used by the Trainer for early-stopping.
+
+  Arguments:
+    checkpoint_path_prefix:
+      Prefix of the path to the checkpoint files.
+    backup_path:
+      path to a directory where checkpoint files will be backed up.
+    empty_backup:
+      When True (the default), the current contents of the backup directory
+      are removed before the backup is performed.
+
+  Returns:
+    The number of backed up files.
+  """
+  checkpoint_file_prefix = os.path.basename(checkpoint_path_prefix)
+
+  if tf.io.gfile.exists(backup_path) and empty_backup:
+    tf.io.gfile.rmtree(backup_path)
+
+  tf.io.gfile.mkdir(backup_path)
+
+  n_backup = 0
+  # copy all checkpoint files to the backup directory
+  try:
+    checkpoint_files = tf.io.gfile.glob(checkpoint_path_prefix + "*")
+    if len(checkpoint_files) == 0:
+      raise twml.errors.CheckpointNotFoundError("%s not found" % checkpoint_path_prefix)
+    for filename in checkpoint_files:
+      n_backup += 1
+      tf.io.gfile.copy(
+        src=filename,
+        dst=os.path.join(backup_path, os.path.basename(filename))
+      )
+  except tf.errors.OpError as ex:
+    raise twml.errors.CheckpointNotFoundError(
+      f"{str(ex)}\n {checkpoint_path_prefix} not found."
+    )
+
+  # tf.train.latest_checkpoint needs the 'checkpoint' file.
+  with tf.io.gfile.GFile(os.path.join(backup_path, 'checkpoint'), 'w') as f:
+    f.write('model_checkpoint_path: "%s"\n' % checkpoint_file_prefix)
+
+  return n_backup
+
+
+def set_only_checkpoint(source_path, dest_path, remove_source=True):
+  """
+  Removes the checkpoint and model.ckpt* files from dest_path.
+  Moves the latest checkpoint from source_path to dest_path.
+
+  Arguments:
+    source_path:
+      path to directory containing the latest checkpoint.
+      Should contain a valid checkpoint file and model.ckpt files.
+      For early-stopping, this should be the save_dir/best_checkpoint dir.
+    dest_path:
+      path to directory where the latest checkpoint files will be moved.
+      All its checkpoint and model.ckpt* files will be removed.
+      For early-stopping, this should be the save_dir.
+    remove_source:
+      When True (the default), deletes the source directory
+      (and any remaining contents).
+      Note that even when False, its checkpoint files are moved to
+      dest_path anyway.
+  """
+  # make it so that the source_path checkpoint is the only checkpoint
+  source_path_prefix = tf.train.latest_checkpoint(source_path)
+  if source_path_prefix is not None:
+    # remove intermediate checkpoints
+    for filename in tf.io.gfile.listdir(dest_path):
+      if filename.startswith("model.ckpt"):
+        tf.io.gfile.remove(os.path.join(dest_path, filename))
+    # move contents of source_path to dest_path
+    for filename in tf.io.gfile.listdir(source_path):
+      tf.io.gfile.rename(
+        oldname=os.path.join(source_path, filename),
+        newname=os.path.join(dest_path, filename),
+        overwrite=True)  # overwrite "checkpoint" file
+    # delete the source_path dir
+    if remove_source:
+      tf.io.gfile.rmtree(source_path)
+
+
+def list_files_by_datetime(
+  base_path,
+  start_datetime,
+  end_datetime=None,
+  datetime_prefix_format='%Y/%m/%d/%H',
+  extension='lzo',
+  parallelism=1,
+  hour_resolution=1,
+  sort=False
+):
+  """List files matching `base_path/datetime_prefix_format/*.extension` for the requested datetime range.
+
+  Args:
+    base_path:
+      The base path. If `None`, returns `None`.
+    start_datetime:
+      A `datetime.datetime` or string representing the start of the range (inclusive).
+      If `None`, it returns `list_files(base_path, extension, sort)`.
+    end_datetime:
+      A `datetime.datetime` or string representing the end of the range (inclusive).
+      If `None`, assumed to be the same as start_datetime.
+    datetime_prefix_format:
+      Format compatible with `datetime.datetime.strftime`
+      (https://docs.python.org/2/library/datetime.html#strftime-and-strptime-behavior).
+    extension:
+      The extension of the files composing the dataset (e.g. 'lzo').
+    parallelism:
+      The number of threads used to process list patterns (this is mostly useful
+      when dealing with filesystems such as HDFS in which listing files is a potentially expensive
+      operation).
+    hour_resolution:
+      The separation between consecutive hours. The default value is 1.
+    sort:
+      bool, whether to return a sorted list of files. Default False.
+
+  Returns:
+    A list with all the matching files.
+
+  Raises:
+    errors.OpError: If there are filesystem / directory listing errors.
+  """
+  if hour_resolution is None:
+    hour_resolution = 1
+
+  if base_path is None:
+    return None
+
+  if start_datetime is None:
+    return list_files(base_path, extension, sort)
+
+  # Do this in case people want to use a single day for training.
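+  # Illustrative (hypothetical path): with the default datetime_prefix_format,
+  #   list_files_by_datetime('hdfs://default/logs', '2021/01/01/00')
+  # lists only the 2021/01/01/00 partition, while also passing
+  # end_datetime='2021/01/01/23' would list all 24 hourly partitions of that day.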
+  if end_datetime is None:
+    end_datetime = start_datetime
+
+  assert parallelism > 0
+  assert start_datetime <= end_datetime
+
+  if isinstance(start_datetime, str):
+    start_datetime = datetime.strptime(start_datetime, datetime_prefix_format)
+
+  if isinstance(end_datetime, str):
+    end_datetime = datetime.strptime(end_datetime, datetime_prefix_format)
+
+  assert isinstance(start_datetime, datetime)
+  assert isinstance(end_datetime, datetime)
+
+  base_path = preprocess_path(base_path)
+
+  def _handle_missing_globs(pattern):
+    try:
+      return tf.io.gfile.glob(pattern)
+    except tf.errors.NotFoundError as e:
+      tf.logging.warning(e.message)
+      return []
+
+  # a set is used because there might be some repeated globs depending on datetime_prefix_format
+  globs = {
+    os.path.join(base_path, dt.strftime(datetime_prefix_format), '*.%s' % extension)
+    for dt in rrule.rrule(
+      freq=rrule.HOURLY, interval=hour_resolution, dtstart=start_datetime, until=end_datetime)
+  }
+  nested_files = Parallel(n_jobs=parallelism, backend='threading')(
+    delayed(_handle_missing_globs)(p) for p in globs
+  )
+  flattened_files = list(itertools.chain.from_iterable(nested_files))
+
+  if not flattened_files:
+    error_msg = "Files list is empty: base_path={base_path}, start_datetime={start_datetime}, end_datetime={end_datetime}".format(
+      base_path=base_path, start_datetime=start_datetime, end_datetime=end_datetime
+    )
+    raise OSError(error_msg)
+
+  if sort:
+    flattened_files = sorted(flattened_files)
+
+  return flattened_files
+
+
+def limit_sparse_tensor_size(sparse_tf, input_size_bits, mask_indices=True):
+  """
+  Returns a ``tf.SparseTensor`` which is the input SparseTensor
+  limited to the specified input_size_bits.
+
+  Args:
+    sparse_tf:
+      twml.SparseTensor or tf.SparseTensor
+    input_size_bits:
+      The number of bits allocated to the input size.
+      Input size will be 2 ** input_size_bits.
+      Note that twml.limit_bits truncates any feature keys that
+      exceed the input size.
+    mask_indices:
+      If mask_indices is False, only the shape is changed. Defaults to True.
+  """
+  if isinstance(sparse_tf, twml.SparseTensor):
+    sparse_tf = sparse_tf.to_tf()
+  if not isinstance(sparse_tf, tf.SparseTensor):
+    raise TypeError('Input argument `sparse_tf` should be of type '
+                    'twml.SparseTensor or tf.SparseTensor. Found type: {}'.
+                    format(type(sparse_tf)))
+  if mask_indices:
+    indices = twml.limit_bits(sparse_tf.indices, input_size_bits)
+  else:
+    indices = sparse_tf.indices
+  dense_shape = tf.stack([sparse_tf.dense_shape[0], 1 << input_size_bits])
+  return tf.SparseTensor(indices=indices, values=sparse_tf.values,
+                         dense_shape=dense_shape)
+
+
+def create_module_spec(mlp_fn, mode, params, drop_collections=None):
+  """
+  Builds the standard tags_and_args and passes them to hub.create_module_spec,
+  i.e. spec = hub.create_module_spec(mlp_fn, tags_and_args=tags_and_args).
+
+  Args:
+    mlp_fn:
+      a function to build a graph for the Module.
+    mode:
+      mode in which the Estimator is run
+    params:
+      parameters passed to the Estimator
+    drop_collections:
+      list of collection names to drop from the module graph (optional).
+  """
+  import tensorflow_hub as hub  # noqa: F402
+  tags_and_args = [(set(), {"params": params, "mode": mode}),  # serving graph
+                   ({"train"}, {"params": params, "mode": mode})  # training graph
+                   ]
+  spec = hub.create_module_spec(mlp_fn, tags_and_args=tags_and_args, drop_collections=drop_collections)
+  return spec
+
+
+def change_name_scope_from_dir(init_scope_name, final_scope_name, save_dir):
+  """
+  Changes the name of the saved scope to the desired name and saves it
+  to the same save_dir.
+
+  Args:
+    init_scope_name:
+      initial scope name
+    final_scope_name:
+      desired (final) scope name
+    save_dir:
+      directory in which the scopes are saved
+
+  In the following section we:
+    - Read all the variables from the latest checkpoint.
+    - Make a copy of the variables with the new name scope.
+    - Store both sets of variables in the latest checkpoint.
+  This essentially doubles up the size of the checkpoint.
+  But when a job is restarted after this part is done, the checkpoint size doubles again.
+  To avoid this, we create a copy in backup if a backup isn't found.
+  This allows us to always read (from backup) and write same-sized checkpoint files.
+  """
+
+  # Create a backup_checkpoints dir
+  backup_dir = os.path.join(save_dir, "change_name_scope_backups")
+  tf.io.gfile.makedirs(backup_dir)
+
+  latest_checkpoint = tf.train.latest_checkpoint(save_dir)
+
+  if latest_checkpoint is None:
+    raise OSError("No checkpoints found in save_dir: %s" % save_dir)
+
+  latest_backup_checkpoint = tf.train.latest_checkpoint(backup_dir)
+
+  if (latest_backup_checkpoint is None or
+      (os.path.basename(latest_checkpoint) !=
+       os.path.basename(latest_backup_checkpoint))):
+    backup_checkpoint(latest_checkpoint, backup_dir, empty_backup=False)
+
+  variables = tf.train.list_variables(backup_dir)
+  with tf.Graph().as_default(), tf.Session().as_default() as sess:
+    new_variables = []
+    for name, _ in variables:
+      var = tf.train.load_variable(backup_dir, name)
+      # Append both the renamed and the original variable
+      new_variables.append(
+        tf.Variable(var, name=name.replace(init_scope_name, final_scope_name)))
+      new_variables.append(tf.Variable(var, name=name))
+    # Save this to the checkpoint in the save_dir
+    saver = tf.train.Saver(new_variables)
+    sess.run(tf.global_variables_initializer())
+    saver.save(sess, latest_checkpoint)  # pylint: disable=no-member
+
+
+def hub_import(input, module, module_name, trainable=False):
+  """
+  Loads an exported hub module.
+
+  Args:
+    input:
+      input to the hub module
+    module:
+      module path
+    module_name:
+      signature of the exported hub module
+    trainable:
+      whether the module variables are trainable. Defaults to False.
+
+  Returns:
+    the output of the hub module for the given signature.
+  """
+  import tensorflow_hub as hub  # noqa: F402
+  hub_module = hub.Module(module, trainable=trainable)
+  output = hub_module(input, signature=module_name)
+  return output
+
+
+def _extract_hash_space_bits(feature_config):
+  """
+  Extracts sparse shapes for a contrib.FeatureConfig.
+
+  Arguments:
+    feature_config:
+      Feature Configuration of the type contrib.FeatureConfig
+  Returns:
+    Dictionary of tensor names and hash space bits.
+  """
+  if not isinstance(feature_config, twml.contrib.feature_config.FeatureConfig):
+    fc_type = type(feature_config)
+    raise TypeError(f"Feature config must be of type contrib.FeatureConfig: {fc_type}")
+  sparse_shapes_dict = {}
+  for config in feature_config.sparse_extraction_configs:
+    sparse_shapes_dict[config.output_name] = config.hash_space_bits
+  return sparse_shapes_dict
+
+
+def fix_shape_sparse(features, feature_config):
+  """
+  Modifies the shape of features which are extracted using the hashing trick.
+  Features itself is changed in place by this function.
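+
+  Illustrative sketch (`features` and `config` are hypothetical objects coming
+  out of a contrib.FeatureConfig extraction with hash_space_bits = 22):
+
+    features = {'sparse_input': sp_tensor}
+    fix_shape_sparse(features, config)
+    # features['sparse_input'].dense_shape[1] is now 2 ** 22
+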
+  Arguments:
+    features:
+      Feature dictionary extracted by the feature config
+    feature_config:
+      Feature Configuration of the type contrib.FeatureConfig
+  """
+  if not isinstance(feature_config, twml.contrib.feature_config.FeatureConfig):
+    raise TypeError(f"Feature config must be of type contrib.FeatureConfig, currently of {type(feature_config)}")
+  sparse_shape = _extract_hash_space_bits(feature_config)
+  if not isinstance(features, dict):
+    raise TypeError(f"features must be of dictionary type, it is of {type(features)} type")
+  for key in set(features) & set(sparse_shape):
+    features[key] = limit_sparse_tensor_size(features[key], sparse_shape[key], mask_indices=False)
+
+
+def touch_file_in_dir(directory, filename):
+  """
+  Creates an empty file named filename in directory.
+
+  Arguments:
+    directory: (str)
+    filename: (str)
+  """
+  file_path = os.path.join(directory, filename)
+  with tf.io.gfile.GFile(file_path, "w") as f:
+    f.write("")
+
+
+def file_exist_in_dir(directory: str, filename: str) -> bool:
+  """Returns True if filename exists in directory."""
+  file_path = os.path.join(directory, filename)
+  return tf.io.gfile.exists(file_path)
+
+
+def copy_to_local(remote, local, filename, overwrite=False):
+  """Copies a file from a remote directory to a local directory."""
+  assert "hdfs://" not in local
+  tf.io.gfile.makedirs(local)
+  return tf.io.gfile.copy(
+    os.path.join(remote, filename),
+    os.path.join(local, filename),
+    overwrite=overwrite,
+  )
+
+
+def copy_recursive(src, dst, overwrite=False):
+  """
+  Copies a directory recursively.
+
+  Arguments:
+    src: Source directory.
+    dst: Destination directory.
+    overwrite: Specifies if files are to be overwritten if they exist.
+  """
+
+  src = src.rstrip("/")
+  dst = dst.rstrip("/")
+
+  for dirname, subdirs, files in tf.io.gfile.walk(src):
+    dst_dirname = dirname.replace(src, dst)
+    tf.io.gfile.makedirs(dst_dirname)
+
+    for f in files:
+      src_f = os.path.join(dirname, f)
+      dst_f = os.path.join(dst_dirname, f)
+
+      tf.logging.info(f"Copying {src_f} to {dst_f}")
+      tf.io.gfile.copy(src_f, dst_f, overwrite=overwrite)
+
+
+def delete_file_or_dir(path):
+  """
+  Deletes the file or directory given by `path`.
+
+  Arguments:
+    path:
+      string indicating the path of the file or directory to remove
+  """
+  if tf.io.gfile.isdir(path):
+    tf.io.gfile.rmtree(path)
+  else:
+    tf.io.gfile.remove(path)
+
+
+def get_distributed_training_job_path():
+  """
+  Returns the distributed training job path.
+  Note: distributed training has three jobs: one parameter server job,
+  one worker job, and one evaluator job. All three job names
+  share a common base job name.
+  """
+  job_path = AuroraPath(dc=os.environ.get("TWML_JOB_CLUSTER"),
+                        role=os.environ.get("TWML_JOB_ROLE"),
+                        env=os.environ.get("TWML_JOB_ENV"),
+                        job_name=os.environ.get("TWML_DISTRIBUTED_BASE_JOBNAME"))
+  return job_path
+
+
+def do_every_n_steps(action, num_steps):
+  """
+  Execute a sequence of TensorFlow operations only once in a while.
+  Specifically, `action` is performed if `global_step` is a
+  multiple of `num_steps`.
+
+  Args:
+    action: callable to be performed at regular intervals. This callable
+      must return a TF op with no output tensors.
+    num_steps: period of performing the action, as measured
+      in number of training steps.
+
+  Returns:
+    A TensorFlow op with no output tensors, like a tf.print() or tf.no_op().
+    You must use tf.control_dependencies() to execute the op.
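+
+  Example (illustrative; assumes `loss` is an existing scalar tensor):
+
+    print_op = do_every_n_steps(lambda: tf.print('loss:', loss), 100)
+    with tf.control_dependencies([print_op]):
+      loss = tf.identity(loss)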
+ + """ + global_step = tf.train.get_or_create_global_step() + condition = tf.math.equal(tf.math.floormod(global_step, num_steps), 0) + return tf.cond(condition, action, lambda: tf.no_op()) diff --git a/twml/twml_common/__init__.py b/twml/twml_common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/twml/twml_common/initializer.py b/twml/twml_common/initializer.py new file mode 100644 index 000000000..7a9c734c7 --- /dev/null +++ b/twml/twml_common/initializer.py @@ -0,0 +1,14 @@ +import tensorflow.compat.v1 as tf + + +class PartitionInitializer(tf.keras.initializers.Initializer): + """Required to initialize partitioned weight with numpy array for tests""" + + def __init__(self, np_array): + self.np_array = np_array + + def __call__(self, shape, dtype=None, partition_info=None): + offset = partition_info.var_offset + ix0, ix1 = offset[0], offset[0] + shape[0] + iy0, iy1 = offset[1], offset[1] + shape[1] + return self.np_array[ix0:ix1, iy0:iy1] diff --git a/twml/twml_common/serialize.py b/twml/twml_common/serialize.py new file mode 100644 index 000000000..36c53881e --- /dev/null +++ b/twml/twml_common/serialize.py @@ -0,0 +1,16 @@ +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + + +def serialize(obj): + tbuf = TTransport.TMemoryBuffer() + iproto = TBinaryProtocol.TBinaryProtocol(tbuf) + obj.write(iproto) + return tbuf.getvalue() + + +def deserialize(record, bytes): + tbuf = TTransport.TMemoryBuffer(bytes) + iproto = TBinaryProtocol.TBinaryProtocol(tbuf) + record.read(iproto) + return record diff --git a/twml/twml_common/sparse_inputs.py b/twml/twml_common/sparse_inputs.py new file mode 100644 index 000000000..b8f7939e5 --- /dev/null +++ b/twml/twml_common/sparse_inputs.py @@ -0,0 +1,24 @@ +import numpy as np +import tensorflow.compat.v1 as tf + + +def create_sparse_tensor(batch_size, input_size, num_values, dtype=tf.float32): + random_indices = np.sort(np.random.randint(batch_size * input_size, size=num_values)) + test_indices_i = random_indices // input_size + test_indices_j = random_indices % input_size + test_indices = np.stack([test_indices_i, test_indices_j], axis=1) + test_values = np.random.random(num_values).astype(dtype.as_numpy_dtype) + + return tf.SparseTensor(indices=tf.constant(test_indices), + values=tf.constant(test_values), + dense_shape=(batch_size, input_size)) + + +def create_reference_input(sparse_input, use_binary_values): + if use_binary_values: + sp_a = tf.SparseTensor(indices=sparse_input.indices, + values=tf.ones_like(sparse_input.values), + dense_shape=sparse_input.dense_shape) + else: + sp_a = sparse_input + return sp_a diff --git a/visibilitylib/BUILD b/visibilitylib/BUILD new file mode 100644 index 000000000..76ea1a659 --- /dev/null +++ b/visibilitylib/BUILD @@ -0,0 +1,29 @@ +target( + dependencies = [ + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) + +target( + name = "conversations", + dependencies = [ + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/conversations", + ], +) + +target( + name = "tweets", + dependencies = [ + "visibility/lib/src/main/scala/com/twitter/visibility/generators", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/tweets", + ], +) + +target( + name = "users", + dependencies = [ + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/users", + ], +) 
diff --git a/visibilitylib/README.md b/visibilitylib/README.md
new file mode 100644
index 000000000..28c7af03f
--- /dev/null
+++ b/visibilitylib/README.md
@@ -0,0 +1,51 @@
+Overview
+========
+
+Visibility Filtering is a centralized rule engine that instructs clients how to alter the display of certain Twitter content at read time. The Visibility Filtering library is responsible for filtering Twitter content to support legal compliance, improve product quality, increase user trust, and protect revenue, through the use of hard filtering, visible product treatments, and coarse-grained downranking.
+
+Notice
+======
+
+The Visibility Filtering library is currently being reviewed and rebuilt; part of the code has been removed and is not yet ready to be shared. The remaining code needs further review and will be shared once it is ready. Code comments have also been sanitized.
+
+SafetyLevel
+===========
+
+Represents the product context in which the Viewer is requesting to view the Content (e.g. Timeline, Profile).
+
+Features
+========
+
+Include safety labels and other metadata of a Tweet, flags set on a User (including the Viewer), relationships between Users (e.g. block, follow), User settings, and relationships between Users and Content (e.g. reported for spam).
+
+Action
+======
+
+The way the Visibility Framework instructs the client to respond to the Viewer's request for Content. Actions can include hard filtering (e.g. Drop), soft filtering (e.g. Labels and Interstitials), ranking clues, etc.
+
+Condition
+=========
+
+Returns a boolean when given a map of Features. Conditions can be combined to determine if a Rule should return an Action or the default (Allow).
+
+Policy
+======
+
+Rules are expressed as a sequence in priority order to create a Visibility Policy. The library has one policy per SafetyLevel.
+
+RuleEngine
+==========
+
+Evaluates the Action for a Request.
+
+SafetyLabel
+===========
+
+The primary labeling mechanism for Safety. A label is associated with an entity such as a Tweet, User, Direct Message, media, or Space. Safety labels power different kinds of remediation (e.g. applying a SafetyLabel that results in a Tweet interstitial or notice).
+
+SafetyLabelType
+===============
+
+Describes a particular policy violation for a given noun instance, and usually leads to reduced visibility of the labeled entity in product surfaces. There are many deprecated and experimental safety label types; labels with these safety label types have no effect on VF. Additionally, some safety label types are not used and were not designed for VF.
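+
+Example
+=======
+
+The following sketch illustrates how Conditions, Rules, Policies, and the RuleEngine fit together. It is illustrative pseudocode (written in Python for brevity); the names are hypothetical and do not reflect the actual Scala API.
+
+```python
+ALLOW = "Allow"
+DROP = "Drop"
+
+
+def author_blocks_viewer(features):
+    # A Condition: returns a boolean when given a map of Features.
+    return features.get("author_blocks_viewer", False)
+
+
+class Rule:
+    def __init__(self, condition, action):
+        self.condition = condition
+        self.action = action
+
+
+class Policy:
+    # A sequence of Rules in priority order; one Policy per SafetyLevel.
+    def __init__(self, rules):
+        self.rules = rules
+
+
+class RuleEngine:
+    def evaluate(self, policy, features):
+        # The first Rule whose Condition holds decides the Action;
+        # otherwise fall through to the default (Allow).
+        for rule in policy.rules:
+            if rule.condition(features):
+                return rule.action
+        return ALLOW
+
+
+policy = Policy([Rule(author_blocks_viewer, DROP)])
+assert RuleEngine().evaluate(policy, {"author_blocks_viewer": True}) == DROP
+```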
diff --git a/visibilitylib/src/main/resources/config/BUILD b/visibilitylib/src/main/resources/config/BUILD new file mode 100644 index 000000000..b45c857ae --- /dev/null +++ b/visibilitylib/src/main/resources/config/BUILD @@ -0,0 +1,6 @@ +resources( + sources = [ + "com/twitter/visibility/*.csv", + "com/twitter/visibility/*.yml", + ], +) diff --git a/visibilitylib/src/main/resources/config/com/twitter/visibility/decider.yml b/visibilitylib/src/main/resources/config/com/twitter/visibility/decider.yml new file mode 100644 index 000000000..c2c8f8a9a --- /dev/null +++ b/visibilitylib/src/main/resources/config/com/twitter/visibility/decider.yml @@ -0,0 +1,906 @@ + +visibility_library_enable_all_subscribed_lists_safety_level: + default_availability: 10000 + +visibility_library_enable_ads_business_settings_safety_level: + default_availability: 10000 + +visibility_library_enable_ads_campaign_safety_level: + default_availability: 10000 + +visibility_library_enable_ads_manager_safety_level: + default_availability: 10000 + +visibility_library_enable_appeals_safety_level: + default_availability: 10000 + +visibility_library_enable_article_tweet_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_birdwatch_note_author_safety_level: + default_availability: 10000 + +visibility_library_enable_birdwatch_note_tweets_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_birdwatch_needs_your_help_notifications_safety_level: + default_availability: 10000 + +visibility_library_enable_block_mute_users_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_brand_safety_safety_level: + default_availability: 10000 + +visibility_library_enable_card_poll_voting_safety_level: + default_availability: 10000 + +visibility_library_enable_cards_service_safety_level: + default_availability: 10000 + +visibility_library_enable_communities_safety_level: + default_availability: 10000 + +visibility_library_enable_conversation_focal_prehydration_safety_level: + default_availability: 10000 + +visibility_library_enable_conversation_focal_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_conversation_injected_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_conversation_reply_safety_level: + default_availability: 10000 + +visibility_library_curated_trends_representative_tweet: + default_availability: 10000 + +visibility_library_curation_policy_violations: + default_availability: 10000 + +visibility_library_enable_deprecated_safety_level_safety_level: + default_availability: 10000 + +visibility_library_enable_dev_platform_get_list_tweets_safety_level: + default_availability: 10000 + +visibility_library_enable_des_following_and_followers_user_list_safety_level: + default_availability: 10000 + +visibility_library_enable_des_home_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_des_quote_tweet_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_des_realtime_safety_level: + default_availability: 10000 + +visibility_library_enable_des_realtime_spam_enrichment_safety_level: + default_availability: 10000 + +visibility_library_enable_des_realtime_tweet_filter_safety_level: + default_availability: 10000 + +visibility_library_enable_des_retweeting_users_safety_level: + default_availability: 10000 + +visibility_library_enable_des_tweet_detail_safety_level: + default_availability: 10000 + 
+visibility_library_enable_des_tweet_liking_users_safety_level: + default_availability: 10000 + +visibility_library_enable_des_user_bookmarks_safety_level: + default_availability: 10000 + +visibility_library_enable_des_user_liked_tweets_safety_level: + default_availability: 10000 + +visibility_library_enable_des_user_mentions_safety_level: + default_availability: 10000 + +visibility_library_enable_des_user_tweets_safety_level: + default_availability: 10000 + +visibility_library_enable_dev_platform_compliance_stream_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_conversation_list_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_conversation_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_inbox_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_muted_users_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_pinned_safety_level: + default_availability: 10000 + +visibility_library_enable_elevated_quote_tweet_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_direct_messages_search_safety_level: + default_availability: 10000 + +visibility_library_enable_embedded_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_embeds_public_interest_notice_safety_level: + default_availability: 10000 + +visibility_library_enable_embed_tweet_markup_safety_level: + default_availability: 10000 + +visibility_library_enable_write_path_limited_actions_enforcement_safety_level: + default_availability: 10000 + +visibility_library_enable_filter_all_safety_level: + default_availability: 10000 + +visibility_library_enable_filter_all_placeholder_safety_level: + default_availability: 10000 + +visibility_library_enable_filter_default_safety_level: + default_availability: 10000 + +visibility_library_enable_filter_none_safety_level: + default_availability: 10000 + +visibility_library_enable_followed_topics_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_follower_connections_safety_level: + default_availability: 10000 + +visibility_library_enable_following_and_followers_user_list_safety_level: + default_availability: 10000 + +visibility_library_enable_for_development_only_safety_level: + default_availability: 10000 + +visibility_library_enable_friends_following_list_safety_level: + default_availability: 10000 + +visibility_library_enable_graphql_default_safety_level: + default_availability: 10000 + +visibility_library_enable_gryphon_decks_and_columns_safety_level: + default_availability: 10000 + +visibility_library_enable_humanization_nudge_safety_level: + default_availability: 10000 + +visibility_library_enable_kitchen_sink_development_safety_level: + default_availability: 10000 + +visibility_library_enable_list_header_safety_level: + default_availability: 10000 + +visibility_library_enable_list_memberships_safety_level: + default_availability: 10000 + +visibility_library_enable_list_ownerships_safety_level: + default_availability: 10000 + +visibility_library_enable_list_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_list_search_safety_level: + default_availability: 10000 + +visibility_library_enable_list_subscriptions_safety_level: + default_availability: 10000 + 
+visibility_library_enable_live_pipeline_engagement_counts_safety_level: + default_availability: 10000 + +visibility_library_enable_live_video_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_magic_recs_safety_level: + default_availability: 10000 + +visibility_library_enable_magic_recs_aggressive_safety_level: + default_availability: 10000 + +visibility_library_enable_magic_recs_aggressive_v2_safety_level: + default_availability: 10000 + +visibility_library_enable_magic_recs_v2_safety_level: + default_availability: 10000 + +visibility_library_enable_minimal_safety_level: + default_availability: 10000 + +visibility_library_enable_moderated_tweets_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_moments_safety_level: + default_availability: 10000 + +visibility_library_enable_nearby_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_new_user_experience_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_ibis_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_platform_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_platform_push_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_read_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_timeline_device_follow_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_write_safety_level: + default_availability: 10000 + +visibility_library_enable_notification_writer_v2_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_writer_tweet_hydrator_safety_level: + default_availability: 10000 + +visibility_library_enable_quick_promote_tweet_eligibility_safety_level: + default_availability: 10000 + +visibility_library_enable_quote_tweet_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_quoted_tweet_rules_safety_level: + default_availability: 10000 + +visibility_library_enable_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_recos_video_safety_level: + default_availability: 10000 + +visibility_library_enable_recos_write_path_safety_level: + default_availability: 10000 + +visibility_library_enable_replies_grouping_safety_level: + default_availability: 10000 + +visibility_library_enable_report_center_safety_level: + default_availability: 10000 + +visibility_library_enable_returning_user_experience_safety_level: + default_availability: 10000 + +visibility_library_enable_returning_user_experience_focal_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_revenue_safety_level: + default_availability: 10000 + +visibility_library_enable_rito_actioned_tweet_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_safe_search_minimal_safety_level: + default_availability: 10000 + +visibility_library_enable_safe_search_strict_safety_level: + default_availability: 10000 + +visibility_library_enable_search_hydration_safety_level: + default_availability: 10000 + +visibility_library_enable_search_latest_safety_level: + default_availability: 10000 + +visibility_library_enable_search_mixer_srp_minimal_safety_level: + default_availability: 10000 + +visibility_library_enable_search_mixer_srp_strict_safety_level: + default_availability: 10000 + +visibility_library_enable_user_search_srp_safety_level: + 
default_availability: 10000 + +visibility_library_enable_user_search_typeahead_safety_level: + default_availability: 10000 + +visibility_library_enable_search_people_srp_safety_level: + default_availability: 10000 + +visibility_library_enable_search_people_typeahead_safety_level: + default_availability: 10000 + +visibility_library_enable_search_photo_safety_level: + default_availability: 10000 + +visibility_library_enable_search_top_safety_level: + default_availability: 10000 + +visibility_library_enable_search_trend_takeover_promoted_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_search_video_safety_level: + default_availability: 10000 + +visibility_library_enable_search_latest_user_rules_safety_level: + default_availability: 10000 + +visibility_library_enable_shopping_manager_spy_mode_safety_level: + default_availability: 10000 + +visibility_library_enable_signals_reactions_safety_level: + default_availability: 10000 + +visibility_library_enable_signals_tweet_reacting_users_safety_level: + default_availability: 10000 + +visibility_library_enable_social_proof_safety_level: + default_availability: 10000 + +visibility_library_enable_soft_intervention_pivot_safety_level: + default_availability: 10000 + +visibility_library_enable_space_fleetline_safety_level: + default_availability: 10000 + +visibility_library_enable_space_home_timeline_upranking_safety_level: + default_availability: 10000 + +visibility_library_enable_space_join_screen_safety_level: + default_availability: 10000 + +visibility_library_enable_space_notifications_safety_level: + default_availability: 10000 + +visibility_library_enable_spaces_safety_level: + default_availability: 10000 + +visibility_library_enable_spaces_participants_safety_level: + default_availability: 0 + +visibility_library_enable_spaces_seller_application_status_safety_level: + default_availability: 10000 + +visibility_library_enable_spaces_sharing_safety_level: + default_availability: 10000 + +visibility_library_enable_space_tweet_avatar_home_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_stickers_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_strato_ext_limited_engagements_safety_level: + default_availability: 10000 + +visibility_library_enable_stream_services_safety_level: + default_availability: 10000 + +visibility_library_enable_test_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_bookmark_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_content_controls_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_conversations_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_conversations_downranking_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_conversations_downranking_minimal_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_favorites_safety_level: + default_availability: 10000 + +visibility_library_enable_self_view_timeline_favorites_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_focal_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_following_activity_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_communities_safety_level: + default_availability: 10000 + 
+visibility_library_enable_timeline_home_hydration_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_latest_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_topic_follow_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_scorer_safety_level: + default_availability: 10000 + +visibility_library_enable_topics_landing_page_topic_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_explore_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_moderated_tweets_hydration_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_injection_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_liked_by_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_lists_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_media_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_mentions_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_profile_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_profile_all_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_profile_spaces_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_profile_super_follows_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_reactive_blending_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_retweeted_by_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_super_liked_by_safety_level: + default_availability: 10000 + +visibility_library_enable_tombstoning_safety_level: + default_availability: 10000 + +visibility_library_enable_trends_representative_tweet_safety_level: + default_availability: 10000 + +visibility_library_enable_trusted_friends_user_list_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_detail_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_detail_non_too_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_detail_with_injections_hydration_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_engagers_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_reply_nudge_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_scoped_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_writes_api_safety_level: + default_availability: 10000 + +visibility_library_enable_twitter_article_compose_safety_level: + default_availability: 10000 + +visibility_library_enable_twitter_article_profile_tab_safety_level: + default_availability: 10000 + +visibility_library_enable_twitter_article_read_safety_level: + default_availability: 10000 + +visibility_library_enable_user_profile_header_safety_level: + default_availability: 10000 + +visibility_library_enable_user_milestone_recommendation_safety_level: + default_availability: 10000 + +visibility_library_enable_user_scoped_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_user_settings_safety_level: + default_availability: 10000 
+ +visibility_library_enable_video_ads_safety_level: + default_availability: 10000 + +visibility_library_enable_timeline_home_promoted_hydration_safety_level: + default_availability: 10000 + +visibility_library_enable_super_follower_connnections_safety_level: + default_availability: 10000 + +visibility_library_enable_super_like_safety_level: + default_availability: 10000 + +visibility_library_enable_topic_recommendations_safety_level: + default_availability: 10000 + +visibility_library_enable_ads_reporting_dashboard_safety_level: + default_availability: 10000 + +visibility_library_enable_search_top_qig_safety_level: + default_availability: 10000 + +visibility_library_enable_content_control_tool_install_safety_level: + default_availability: 10000 + +visibility_library_enable_conversation_control_rules: + default_availability: 10000 + +visibility_library_enable_community_tweets_rules: + default_availability: 10000 + +visibility_library_enable_drop_community_tweet_with_undefined_community_rule: + default_availability: 10000 + +visibility_library_enable_p_spammy_tweet_downrank_convos_low_quality: + default_availability: 10000 + +visibility_library_enable_high_p_spammy_tweet_score_search_tweet_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_rito_actioned_tweet_downrank_convos_low_quality: + default_availability: 10000 + +visibility_library_enable_toxic_reply_filter_conversation: + default_availability: 10000 + +visibility_library_enable_toxic_reply_filter_notifications: + default_availability: 10000 + +visibility_library_enable_new_sensitive_media_settings_interstitial_rules_home_timeline: + default_availability: 10000 + +visibility_library_enable_legacy_sensitive_media_rules_home_timeline: + default_availability: 10000 + +visibility_library_enable_new_sensitive_media_settings_interstitial_rules_conversation: + default_availability: 10000 + +visibility_library_enable_legacy_sensitive_media_rules_conversation: + default_availability: 10000 + +visibility_library_enable_new_sensitive_media_settings_interstitials_rules_profile_timeline: + default_availability: 10000 + +visibility_library_enable_legacy_sensitive_media_rules_profile_timeline: + default_availability: 10000 + +visibility_library_enable_new_sensitive_media_settings_interstitials_rules_tweet_detail: + default_availability: 10000 + +visibility_library_enable_legacy_sensitive_media_rules_tweet_detail: + default_availability: 10000 + +visibility_library_enable_legacy_sensitive_media_rules_direct_messages: + default_availability: 10000 + +visibility_library_enable_smyte_spam_tweet_rule: + default_availability: 10000 + +visibility_library_enable_high_spammy_tweet_content_score_search_latest_tweet_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_high_spammy_tweet_content_score_search_top_tweet_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_high_spammy_tweet_content_score_convo_downrank_abusive_quality_rule: + default_availability: 10000 + +visibility_library_enable_high_cryptospam_score_convo_downrank_abusive_quality_rule: + default_availability: 10000 + +visibility_library_enable_high_spammy_tweet_content_score_trends_top_tweet_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_high_spammy_tweet_content_score_trends_latest_tweet_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_gore_and_violence_topic_high_recall_tweet_label_rule: + default_availability: 10000 + 
+visibility_library_enable_limit_replies_followers_conversation_rule: + default_availability: 10000 + +visibility_library_enable_blink_bad_downranking_rule: + default_availability: 10000 + +visibility_library_enable_blink_worst_downranking_rule: + default_availability: 10000 + +visibility_library_enable_copypasta_spam_downrank_convos_abusive_quality_rule: + default_availability: 10000 + +visibility_library_enable_copypasta_spam_search_drop_rule: + default_availability: 10000 + +visibility_library_enable_spammy_user_model_high_precision_drop_tweet_rule: + default_availability: 10000 + +visibility_library_enable_avoid_nsfw_rules: + default_availability: 10000 + +visibility_library_enable_reported_tweet_interstitial_rule: + default_availability: 10000 + +visibility_library_enable_reported_tweet_interstitial_search_rule: + default_availability: 10000 + +visibility_library_enable_drop_exclusive_tweet_content_rule: + default_availability: 10000 + +visibility_library_enable_drop_exclusive_tweet_content_rule_fail_closed: + default_availability: 10000 + +visibility_library_enable_drop_all_exclusive_tweets_rule: + default_availability: 10000 + +visibility_library_enable_drop_all_exclusive_tweets_rule_fail_closed: + default_availability: 10000 + +visibility_library_enable_tombstone_exclusive_quoted_tweet_content_rule: + default_availability: 10000 + +visibility_library_enable_downrank_spam_reply_sectioning_rule: + default_availability: 10000 + +visibility_library_enable_nsfw_text_sectioning_rule: + default_availability: 10000 + +visibility_library_enable_search_ipi_safe_search_without_user_in_query_drop_rule: + default_availability: 10000 + +visibility_library_enable_timeline_home_promoted_tweet_health_enforcement_rules: + default_availability: 10000 + +visibility_library_enable_muted_keyword_filtering_space_title_notifications_rule: + default_availability: 10000 + +visibility_library_enable_drop_tweets_with_georestricted_media_rule: + default_availability: 10000 + +visibility_library_enable_drop_all_trusted_friends_tweets_rule: + default_availability: 10000 + +visibility_library_enable_drop_all_trusted_friends_tweet_content_rule: + default_availability: 10000 + +visibility_library_enable_drop_all_collab_invitation_tweets_rule: + default_availability: 10000 + +visibility_library_enable_fetch_tweet_reported_perspective: + default_availability: 10000 + +visibility_library_enable_fetch_tweet_media_metadata: + default_availability: 10000 + +visibility_library_enable_follow_check_in_mutedkeyword: + default_availability: 10000 + +visibility_library_enable_media_interstitial_composition: + default_availability: 10000 + +visibility_library_enable_verdict_scribing_from_tweet_visibility_library: + default_availability: 0 + +visibility_library_enable_verdict_logger_event_publisher_instantiation_from_tweet_visibility_library: + default_availability: 10000 + +visibility_library_enable_verdict_scribing_from_timeline_conversations_visibility_library: + default_availability: 0 + +visibility_library_enable_verdict_logger_event_publisher_instantiation_from_timeline_conversations_visibility_library: + default_availability: 10000 + +visibility_library_enable_verdict_scribing_from_blender_visibility_library: + default_availability: 0 + +visibility_library_enable_verdict_logger_event_publisher_instantiation_from_blender_visibility_library: + default_availability: 10000 + +visibility_library_enable_verdict_scribing_from_search_visibility_library: + default_availability: 0 + 
+visibility_library_enable_verdict_logger_event_publisher_instantiation_from_search_visibility_library: + default_availability: 0 + +visibility_library_enable_localized_tombstones_on_visibility_results: + default_availability: 10000 + +visibility_library_enable_short_circuiting_from_tweet_visibility_library: + default_availability: 10000 + +visibility_library_enable_card_visibility_library_card_uri_parsing: + default_availability: 10000 + +visibility_library_enable_short_circuiting_from_timeline_conversations_visibility_library: + default_availability: 10000 + +visibility_library_enable_short_circuiting_from_blender_visibility_library: + default_availability: 10000 + +visibility_library_enable_short_circuiting_from_search_visibility_library: + default_availability: 0 + +visibility_library_enable_nsfw_text_topics_drop_rule: + default_availability: 10000 + +visibility_library_enable_spammy_tweet_rule_verdict_logging: + default_availability: 0 + +visibility_library_enable_downlevel_rule_verdict_logging: + default_availability: 0 + +visibility_library_enable_likely_likely_ivs_user_label_drop_rule: + default_availability: 10000 + +visibility_library_enable_card_uri_root_domain_deny_list_rule: + default_availability: 10000 + +visibility_library_enable_community_non_member_poll_card_rule: + default_availability: 10000 + +visibility_library_enable_community_non_member_poll_card_rule_fail_closed: + default_availability: 10000 + +visibility_library_enable_experimental_nudge_label_rule: + default_availability: 10000 + +visibility_library_enable_user_self_view_only_safety_level: + default_availability: 10000 + +visibility_library_nsfw_high_precision_user_label_avoid_tweet_rule_enabled: + default_availability: 10000 + +visibility_library_enable_new_ad_avoidance_rules: + default_availability: 10000 + +visibility_library_enable_nsfa_high_recall_ad_avoidance_rules: + default_availability: 0 + +visibility_library_enable_nsfa_keywords_high_precision_ad_avoidance_rules: + default_availability: 0 + +visibility_library_enable_stale_tweet_drop_rule: + default_availability: 10000 + +visibility_library_enable_stale_tweet_drop_rule_fail_closed: + default_availability: 10000 + +visibility_library_enable_edit_history_timeline_safety_level: + default_availability: 10000 + +visibility_library_enable_delete_state_tweet_rules: + default_availability: 10000 + +visibility_library_enable_spaces_sharing_nsfw_drop_rule: + default_availability: 10000 + +visibility_library_enable_viewer_is_soft_user_drop_rule: + default_availability: 10000 + +visibility_library_enable_backend_limited_actions: + default_availability: 10000 + +visibility_library_enable_base_qig_safety_level: + default_availability: 10000 + +visibility_library_enable_notifications_qig_safety_level: + default_availability: 10000 + +visibility_library_enable_access_internal_promoted_content_safety_level: + default_availability: 10000 + +visibility_library_enable_pdna_quoted_tweet_tombstone_rule: + default_availability: 10000 + +visibility_library_enable_spam_quoted_tweet_tombstone_rule: + default_availability: 10000 + +visibility_library_enable_nsfw_hp_quoted_tweet_drop_experiment_rule: + default_availability: 10000 + +visibility_library_enable_nsfw_hp_quoted_tweet_tombstone_experiment_rule: + default_availability: 10000 + +visibility_library_enable_inner_quoted_tweet_viewer_blocks_author_interstitial_rule: + default_availability: 10000 + +visibility_library_enable_inner_quoted_tweet_viewer_mutes_author_interstitial_rule: + default_availability: 10000 + 
+visibility_library_enable_experimental_rule_engine: + default_availability: 10000 + +visibility_library_enable_fosnr_rules: + default_availability: 0 + +visibility_library_enable_localized_interstitial_generator: + default_availability: 10000 + +visibility_library_convos_enable_legacy_interstitial: + default_availability: 10000 + +visibility_library_convos_enable_localized_interstitial: + default_availability: 10000 + +visibility_library_enable_profile_mixer_media_safety_level: + default_availability: 10000 + +visibility_library_enable_profile_mixer_favorites_safety_level: + default_availability: 10000 + +visibility_library_enable_zipbird_consumer_archives_safety_level: + default_availability: 10000 + +visibility_library_enable_tweet_award_safety_level: + default_availability: 10000 + +visibility_library_disable_legacy_interstitial_filtered_reason: + default_availability: 10000 + +visibility_library_enable_search_basic_block_mute_rules: + default_availability: 10000 + +visibility_library_enable_localized_interstitial_user_state_lib: + default_availability: 10000 + +visibility_library_enable_abusive_behavior_drop_rule: + default_availability: 10000 + +visibility_library_enable_abusive_behavior_interstitial_rule: + default_availability: 10000 + +visibility_library_enable_abusive_behavior_limited_engagements_rule: + default_availability: 10000 + +visibility_library_enable_not_graduated_downrank_convos_abusive_quality_rule: + default_availability: 0 + +visibility_library_enable_not_graduated_search_drop_rule: + default_availability: 0 + +visibility_library_enable_not_graduated_drop_rule: + default_availability: 0 + +visibility_library_enable_memoize_safety_level_params: + default_availability: 0 + +visibility_library_enable_author_blocks_viewer_drop_rule: + default_availability: 0 diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/BUILD new file mode 100644 index 000000000..47501c7d4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/BUILD @@ -0,0 +1,42 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "abdecider/src/main/scala", + "configapi/configapi-core", + "decider/src/main/scala", + "featureswitches/featureswitches-core/src/main/scala", + "servo/decider/src/main/scala", + "servo/util/src/main/scala", + "stitch/stitch-core", + "util/util-logging/src/main/scala", + "util/util-stats/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility/builder", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/engine", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/generators", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/generators", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + 
"visibility/lib/src/main/scala/com/twitter/visibility/util", + ], + exports = [ + "configapi/configapi-core", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/lib/src/main/scala/com/twitter/visibility/builder", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/VisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/VisibilityLibrary.scala new file mode 100644 index 000000000..1e89c8818 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/VisibilityLibrary.scala @@ -0,0 +1,387 @@ +package com.twitter.visibility + +import com.twitter.abdecider.LoggingABDecider +import com.twitter.abdecider.NullABDecider +import com.twitter.decider.Decider +import com.twitter.decider.NullDecider +import com.twitter.featureswitches.v2.FeatureSwitches +import com.twitter.featureswitches.v2.NullFeatureSwitches +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logging.Logger +import com.twitter.logging.NullLogger +import com.twitter.servo.util.Gate +import com.twitter.servo.util.MemoizingStatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.timelines.configapi.Params +import com.twitter.util.Try +import com.twitter.visibility.builder._ +import com.twitter.visibility.common.stitch.StitchHelpers +import com.twitter.visibility.configapi.VisibilityParams +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.engine.DeciderableVisibilityRuleEngine +import com.twitter.visibility.engine.VisibilityResultsMetricRecorder +import com.twitter.visibility.engine.VisibilityRuleEngine +import com.twitter.visibility.engine.VisibilityRulePreprocessor +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.generators.TweetRuleGenerator +import com.twitter.visibility.rules.providers.InjectedPolicyProvider +import com.twitter.visibility.util.DeciderUtil +import com.twitter.visibility.util.FeatureSwitchUtil +import com.twitter.visibility.util.LoggingUtil + +object VisibilityLibrary { + + object Builder { + + def apply(log: Logger, statsReceiver: StatsReceiver): Builder = new Builder( + log, + new MemoizingStatsReceiver(statsReceiver) + ) + } + + case class Builder( + log: Logger, + statsReceiver: StatsReceiver, + decider: Option[Decider] = None, + abDecider: Option[LoggingABDecider] = None, + featureSwitches: Option[FeatureSwitches] = None, + enableStitchProfiling: Gate[Unit] = Gate.False, + captureDebugStats: Gate[Unit] = Gate.False, + enableComposableActions: Gate[Unit] = Gate.False, + enableFailClosed: Gate[Unit] = Gate.False, + enableShortCircuiting: Gate[Unit] = Gate.True, + memoizeSafetyLevelParams: Gate[Unit] = Gate.False) { + + def withDecider(decider: Decider): Builder = copy(decider = Some(decider)) + + @deprecated("use .withDecider and pass in a decider that is properly configured per DC") + def withDefaultDecider(isLocal: Boolean, useLocalOverrides: 
Boolean = false): Builder = { + if (isLocal) { + withLocalDecider + } else { + withDecider( + DeciderUtil.mkDecider( + useLocalDeciderOverrides = useLocalOverrides, + )) + } + } + + def withLocalDecider(): Builder = withDecider(DeciderUtil.mkLocalDecider) + + def withNullDecider(): Builder = + withDecider(new NullDecider(isAvail = true, availabilityDefined = true)) + + def withABDecider(abDecider: LoggingABDecider, featureSwitches: FeatureSwitches): Builder = + abDecider match { + case abd: NullABDecider => + copy(abDecider = Some(abd), featureSwitches = Some(NullFeatureSwitches)) + case _ => + copy( + abDecider = Some(abDecider), + featureSwitches = Some(featureSwitches) + ) + } + + def withABDecider(abDecider: LoggingABDecider): Builder = abDecider match { + case abd: NullABDecider => + withABDecider(abDecider = abd, featureSwitches = NullFeatureSwitches) + case _ => + withABDecider( + abDecider = abDecider, + featureSwitches = + FeatureSwitchUtil.mkVisibilityLibraryFeatureSwitches(abDecider, statsReceiver) + ) + } + + def withClientEventsLogger(clientEventsLogger: Logger): Builder = + withABDecider(DeciderUtil.mkABDecider(Some(clientEventsLogger))) + + def withDefaultABDecider(isLocal: Boolean): Builder = + if (isLocal) { + withABDecider(NullABDecider) + } else { + withClientEventsLogger(LoggingUtil.mkDefaultLogger(statsReceiver)) + } + + def withNullABDecider(): Builder = withABDecider(NullABDecider) + + def withEnableStitchProfiling(gate: Gate[Unit]): Builder = + copy(enableStitchProfiling = gate) + + def withCaptureDebugStats(gate: Gate[Unit]): Builder = + copy(captureDebugStats = gate) + + def withEnableComposableActions(gate: Gate[Unit]): Builder = + copy(enableComposableActions = gate) + + def withEnableComposableActions(gateBoolean: Boolean): Builder = { + val gate = Gate.const(gateBoolean) + copy(enableComposableActions = gate) + } + + def withEnableFailClosed(gate: Gate[Unit]): Builder = + copy(enableFailClosed = gate) + + def withEnableFailClosed(gateBoolean: Boolean): Builder = { + val gate = Gate.const(gateBoolean) + copy(enableFailClosed = gate) + } + + def withEnableShortCircuiting(gate: Gate[Unit]): Builder = + copy(enableShortCircuiting = gate) + + def withEnableShortCircuiting(gateBoolean: Boolean): Builder = { + val gate = Gate.const(gateBoolean) + copy(enableShortCircuiting = gate) + } + + def memoizeSafetyLevelParams(gate: Gate[Unit]): Builder = + copy(memoizeSafetyLevelParams = gate) + + def memoizeSafetyLevelParams(gateBoolean: Boolean): Builder = { + val gate = Gate.const(gateBoolean) + copy(memoizeSafetyLevelParams = gate) + } + + def build(): VisibilityLibrary = { + + (decider, abDecider, featureSwitches) match { + case (None, _, _) => + throw new IllegalStateException( + "Decider is unset! If intentional, please call .withNullDecider()." + ) + + case (_, None, _) => + throw new IllegalStateException( + "ABDecider is unset! If intentional, please call .withNullABDecider()." + ) + + case (_, _, None) => + throw new IllegalStateException( + "FeatureSwitches is unset! This is a bug." 
+ ) + + case (Some(d), Some(abd), Some(fs)) => + new VisibilityLibrary( + statsReceiver, + d, + abd, + VisibilityParams(log, statsReceiver, d, abd, fs), + enableStitchProfiling = enableStitchProfiling, + captureDebugStats = captureDebugStats, + enableComposableActions = enableComposableActions, + enableFailClosed = enableFailClosed, + enableShortCircuiting = enableShortCircuiting, + memoizeSafetyLevelParams = memoizeSafetyLevelParams) + } + } + } + + val nullDecider = new NullDecider(true, true) + + lazy val NullLibrary: VisibilityLibrary = new VisibilityLibrary( + NullStatsReceiver, + nullDecider, + NullABDecider, + VisibilityParams( + NullLogger, + NullStatsReceiver, + nullDecider, + NullABDecider, + NullFeatureSwitches), + enableStitchProfiling = Gate.False, + captureDebugStats = Gate.False, + enableComposableActions = Gate.False, + enableFailClosed = Gate.False, + enableShortCircuiting = Gate.True, + memoizeSafetyLevelParams = Gate.False + ) +} + +class VisibilityLibrary private[VisibilityLibrary] ( + baseStatsReceiver: StatsReceiver, + decider: Decider, + abDecider: LoggingABDecider, + visibilityParams: VisibilityParams, + enableStitchProfiling: Gate[Unit], + captureDebugStats: Gate[Unit], + enableComposableActions: Gate[Unit], + enableFailClosed: Gate[Unit], + enableShortCircuiting: Gate[Unit], + memoizeSafetyLevelParams: Gate[Unit]) { + + val statsReceiver: StatsReceiver = + new MemoizingStatsReceiver(baseStatsReceiver.scope("visibility_library")) + + val metricsRecorder = VisibilityResultsMetricRecorder(statsReceiver, captureDebugStats) + + val visParams: VisibilityParams = visibilityParams + + val visibilityDeciderGates = VisibilityDeciderGates(decider) + + val profileStats: MemoizingStatsReceiver = new MemoizingStatsReceiver( + statsReceiver.scope("profiling")) + + val perSafetyLevelProfileStats: StatsReceiver = profileStats.scope("for_safety_level") + + val featureMapBuilder: FeatureMapBuilder.Build = + FeatureMapBuilder(statsReceiver, enableStitchProfiling) + + private lazy val tweetRuleGenerator = new TweetRuleGenerator() + lazy val policyProvider = new InjectedPolicyProvider( + visibilityDeciderGates = visibilityDeciderGates, + tweetRuleGenerator = tweetRuleGenerator) + + val candidateVisibilityRulePreprocessor: VisibilityRulePreprocessor = VisibilityRulePreprocessor( + metricsRecorder, + policyProviderOpt = Some(policyProvider) + ) + + val fallbackVisibilityRulePreprocessor: VisibilityRulePreprocessor = VisibilityRulePreprocessor( + metricsRecorder) + + lazy val candidateVisibilityRuleEngine: VisibilityRuleEngine = VisibilityRuleEngine( + Some(candidateVisibilityRulePreprocessor), + metricsRecorder, + enableComposableActions, + enableFailClosed, + policyProviderOpt = Some(policyProvider) + ) + + lazy val fallbackVisibilityRuleEngine: VisibilityRuleEngine = VisibilityRuleEngine( + Some(fallbackVisibilityRulePreprocessor), + metricsRecorder, + enableComposableActions, + enableFailClosed) + + val ruleEngineVersionStatsReceiver = statsReceiver.scope("rule_engine_version") + def isReleaseCandidateEnabled: Boolean = visibilityDeciderGates.enableExperimentalRuleEngine() + + private def visibilityRuleEngine: DeciderableVisibilityRuleEngine = { + if (isReleaseCandidateEnabled) { + ruleEngineVersionStatsReceiver.counter("release_candidate").incr() + candidateVisibilityRuleEngine + } else { + ruleEngineVersionStatsReceiver.counter("fallback").incr() + fallbackVisibilityRuleEngine + } + } + + private def profileStitch[A](result: Stitch[A], safetyLevelName: String): Stitch[A] = + if 
(enableStitchProfiling()) { + StitchHelpers.profileStitch( + result, + Seq(profileStats, perSafetyLevelProfileStats.scope(safetyLevelName)) + ) + } else { + result + } + + def getParams(viewerContext: ViewerContext, safetyLevel: SafetyLevel): Params = { + if (memoizeSafetyLevelParams()) { + visibilityParams.memoized(viewerContext, safetyLevel) + } else { + visibilityParams(viewerContext, safetyLevel) + } + } + + def evaluationContextBuilder(viewerContext: ViewerContext): EvaluationContext.Builder = + EvaluationContext + .Builder(statsReceiver, visibilityParams, viewerContext) + .withMemoizedParams(memoizeSafetyLevelParams) + + def runRuleEngine( + contentId: ContentId, + featureMap: FeatureMap, + evaluationContextBuilder: EvaluationContext.Builder, + safetyLevel: SafetyLevel + ): Stitch[VisibilityResult] = + profileStitch( + visibilityRuleEngine( + evaluationContextBuilder.build(safetyLevel), + safetyLevel, + new VisibilityResultBuilder(contentId, featureMap), + enableShortCircuiting + ), + safetyLevel.name + ) + + def runRuleEngine( + contentId: ContentId, + featureMap: FeatureMap, + viewerContext: ViewerContext, + safetyLevel: SafetyLevel + ): Stitch[VisibilityResult] = + profileStitch( + visibilityRuleEngine( + EvaluationContext(safetyLevel, getParams(viewerContext, safetyLevel), statsReceiver), + safetyLevel, + new VisibilityResultBuilder(contentId, featureMap), + enableShortCircuiting + ), + safetyLevel.name + ) + + def runRuleEngine( + viewerContext: ViewerContext, + safetyLevel: SafetyLevel, + preprocessedResultBuilder: VisibilityResultBuilder, + preprocessedRules: Seq[Rule] + ): Stitch[VisibilityResult] = + profileStitch( + visibilityRuleEngine( + EvaluationContext(safetyLevel, getParams(viewerContext, safetyLevel), statsReceiver), + safetyLevel, + preprocessedResultBuilder, + enableShortCircuiting, + Some(preprocessedRules) + ), + safetyLevel.name + ) + + def runRuleEngineBatch( + contentIds: Seq[ContentId], + featureMapProvider: (ContentId, SafetyLevel) => FeatureMap, + viewerContext: ViewerContext, + safetyLevel: SafetyLevel, + ): Stitch[Seq[Try[VisibilityResult]]] = { + val params = getParams(viewerContext, safetyLevel) + profileStitch( + Stitch.traverse(contentIds) { contentId => + visibilityRuleEngine( + EvaluationContext(safetyLevel, params, NullStatsReceiver), + safetyLevel, + new VisibilityResultBuilder(contentId, featureMapProvider(contentId, safetyLevel)), + enableShortCircuiting + ).liftToTry + }, + safetyLevel.name + ) + } + + def runRuleEngineBatch( + contentIds: Seq[ContentId], + featureMapProvider: (ContentId, SafetyLevel) => FeatureMap, + evaluationContextBuilder: EvaluationContext.Builder, + safetyLevel: SafetyLevel + ): Stitch[Seq[Try[VisibilityResult]]] = { + val evaluationContext = evaluationContextBuilder.build(safetyLevel) + profileStitch( + Stitch.traverse(contentIds) { contentId => + visibilityRuleEngine( + evaluationContext, + safetyLevel, + new VisibilityResultBuilder(contentId, featureMapProvider(contentId, safetyLevel)), + enableShortCircuiting + ).liftToTry + }, + safetyLevel.name + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/BUILD new file mode 100644 index 000000000..46f9c1147 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/BUILD @@ -0,0 +1,29 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + 
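# Note: strict_deps = True limits the compile-time classpath to the direct dependencies declared below. + 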
dependencies = [ + "3rdparty/jvm/com/twitter/src/java/com/twitter/logpipeline/client:logpipeline-event-publisher-thin", + "configapi/configapi-core", + "decider/src/main/scala", + "servo/util/src/main/scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "stitch/stitch-core", + "util/util-stats/src/main/scala/com/twitter/finagle/stats", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/FeatureMapBuilder.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/FeatureMapBuilder.scala new file mode 100644 index 000000000..946879c5e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/FeatureMapBuilder.scala @@ -0,0 +1,64 @@ +package com.twitter.visibility.builder + +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.visibility.features._ +import com.twitter.visibility.common.stitch.StitchHelpers +import scala.collection.mutable + +object FeatureMapBuilder { + type Build = Seq[FeatureMapBuilder => FeatureMapBuilder] => FeatureMap + + def apply( + statsReceiver: StatsReceiver = NullStatsReceiver, + enableStitchProfiling: Gate[Unit] = Gate.False + ): Build = + fns => + Function + .chain(fns).apply( + new FeatureMapBuilder(statsReceiver, enableStitchProfiling) + ).build +} + +class FeatureMapBuilder private[builder] ( + statsReceiver: StatsReceiver, + enableStitchProfiling: Gate[Unit] = Gate.False) { + + private[this] val hydratedScope = + statsReceiver.scope("visibility_result_builder").scope("hydrated") + + val mapBuilder: mutable.Builder[(Feature[_], Stitch[_]), Map[Feature[_], Stitch[_]]] = + Map.newBuilder[Feature[_], Stitch[_]] + + val constantMapBuilder: mutable.Builder[(Feature[_], Any), Map[Feature[_], Any]] = + Map.newBuilder[Feature[_], Any] + + def build: FeatureMap = new FeatureMap(mapBuilder.result(), constantMapBuilder.result()) + + def withConstantFeature[T](feature: Feature[T], value: T): FeatureMapBuilder = { + val anyValue: Any = value.asInstanceOf[Any] + constantMapBuilder += (feature -> anyValue) + this + } + + def withFeature[T](feature: Feature[T], stitch: Stitch[T]): FeatureMapBuilder = { + val profiledStitch = if (enableStitchProfiling()) { + val featureScope = hydratedScope.scope(feature.name) + StitchHelpers.profileStitch(stitch, Seq(hydratedScope, featureScope)) + } else { + stitch + } + + val featureStitchRef = Stitch.ref(profiledStitch) + + mapBuilder += FeatureMap.rescueFeatureTuple(feature -> featureStitchRef) + + this + } + + def 
withConstantFeature[T](feature: Feature[T], option: Option[T]): FeatureMapBuilder = { + option.map(withConstantFeature(feature, _)).getOrElse(this) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/VerdictLogger.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VerdictLogger.scala new file mode 100644 index 000000000..a47a35fa2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VerdictLogger.scala @@ -0,0 +1,187 @@ +package com.twitter.visibility.builder + +import com.twitter.datatools.entityservice.entities.thriftscala.FleetInterstitial +import com.twitter.decider.Decider +import com.twitter.decider.Decider.NullDecider +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logpipeline.client.common.EventPublisher +import com.twitter.logpipeline.client.EventPublisherManager +import com.twitter.logpipeline.client.serializers.EventLogMsgThriftStructSerializer +import com.twitter.spam.rtf.thriftscala.SafetyLevel +import com.twitter.visibility.builder.VerdictLogger.FailureCounterName +import com.twitter.visibility.builder.VerdictLogger.SuccessCounterName +import com.twitter.visibility.features.Feature +import com.twitter.visibility.logging.thriftscala.ActionSource +import com.twitter.visibility.logging.thriftscala.EntityId +import com.twitter.visibility.logging.thriftscala.EntityIdType +import com.twitter.visibility.logging.thriftscala.EntityIdValue +import com.twitter.visibility.logging.thriftscala.HealthActionType +import com.twitter.visibility.logging.thriftscala.MisinfoPolicyCategory +import com.twitter.visibility.logging.thriftscala.VFLibType +import com.twitter.visibility.logging.thriftscala.VFVerdictLogEntry +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.rules._ + +object VerdictLogger { + + private val BaseStatsNamespace = "vf_verdict_logger" + private val FailureCounterName = "failures" + private val SuccessCounterName = "successes" + val LogCategoryName: String = "visibility_filtering_verdicts" + + val Empty: VerdictLogger = new VerdictLogger(NullStatsReceiver, NullDecider, None) + + def apply( + statsReceiver: StatsReceiver, + decider: Decider + ): VerdictLogger = { + val eventPublisher: EventPublisher[VFVerdictLogEntry] = + EventPublisherManager + .newScribePublisherBuilder( + LogCategoryName, + EventLogMsgThriftStructSerializer.getNewSerializer[VFVerdictLogEntry]()).build() + new VerdictLogger(statsReceiver.scope(BaseStatsNamespace), decider, Some(eventPublisher)) + } +} + +class VerdictLogger( + statsReceiver: StatsReceiver, + decider: Decider, + publisherOpt: Option[EventPublisher[VFVerdictLogEntry]]) { + + def log( + verdictLogEntry: VFVerdictLogEntry, + publisher: EventPublisher[VFVerdictLogEntry] + ): Unit = { + publisher + .publish(verdictLogEntry) + .onSuccess(_ => statsReceiver.counter(SuccessCounterName).incr()) + .onFailure { e => + statsReceiver.counter(FailureCounterName).incr() + statsReceiver.scope(FailureCounterName).counter(e.getClass.getName).incr() + } + } + + private def toEntityId(contentId: ContentId): Option[EntityId] = { + contentId match { + case ContentId.TweetId(id) => Some(EntityId(EntityIdType.TweetId, EntityIdValue.EntityId(id))) + case ContentId.UserId(id) => Some(EntityId(EntityIdType.UserId, EntityIdValue.EntityId(id))) + case ContentId.QuotedTweetRelationship(outerTweetId, _) => + Some(EntityId(EntityIdType.TweetId, EntityIdValue.EntityId(outerTweetId))) + case 
ContentId.NotificationId(Some(id)) =>
+        Some(EntityId(EntityIdType.NotificationId, EntityIdValue.EntityId(id)))
+      case ContentId.DmId(id) => Some(EntityId(EntityIdType.DmId, EntityIdValue.EntityId(id)))
+      case ContentId.BlenderTweetId(id) =>
+        Some(EntityId(EntityIdType.TweetId, EntityIdValue.EntityId(id)))
+      case ContentId.SpacePlusUserId(_) =>
+        None
+      case _ => None // content types without a loggable entity id are skipped by scribeVerdict
+    }
+  }
+
+  // Collects the names, action sources, and action types of every acting rule that opted into verdict logging.
+  private def getLogEntryData(
+    actingRule: Option[Rule],
+    secondaryActingRules: Seq[Rule],
+    verdict: Action,
+    secondaryVerdicts: Seq[Action],
+    resolvedFeatureMap: Map[Feature[_], Any]
+  ): (Seq[String], Seq[ActionSource], Seq[HealthActionType], Option[FleetInterstitial]) = {
+    actingRule
+      .filter {
+        case decideredRule: DoesLogVerdictDecidered =>
+          decider.isAvailable(decideredRule.verdictLogDeciderKey.toString)
+        case _: DoesLogVerdict => true
+        case _ => false
+      }
+      .map { primaryRule =>
+        val secondaryRulesAndVerdicts = secondaryActingRules zip secondaryVerdicts
+        var actingRules: Seq[Rule] = Seq(primaryRule)
+        var actingRuleNames: Seq[String] = Seq(primaryRule.name)
+        var actionSources: Seq[ActionSource] = Seq()
+        var healthActionTypes: Seq[HealthActionType] = Seq(verdict.toHealthActionTypeThrift.get)
+
+        val misinfoPolicyCategory: Option[FleetInterstitial] = {
+          verdict match {
+            case softIntervention: SoftIntervention =>
+              softIntervention.fleetInterstitial
+            case tweetInterstitial: TweetInterstitial =>
+              tweetInterstitial.softIntervention.flatMap(_.fleetInterstitial)
+            case _ => None
+          }
+        }
+
+        secondaryRulesAndVerdicts.foreach(ruleAndVerdict => {
+          if (ruleAndVerdict._1.isInstanceOf[DoesLogVerdict]) {
+            actingRules = actingRules :+ ruleAndVerdict._1
+            actingRuleNames = actingRuleNames :+ ruleAndVerdict._1.name
+            healthActionTypes = healthActionTypes :+ ruleAndVerdict._2.toHealthActionTypeThrift.get
+          }
+        })
+
+        actingRules.foreach(rule => {
+          rule.actionSourceBuilder
+            .flatMap(_.build(resolvedFeatureMap, verdict))
+            .map(actionSource => {
+              actionSources = actionSources :+ actionSource
+            })
+        })
+        (actingRuleNames, actionSources, healthActionTypes, misinfoPolicyCategory)
+      }
+      .getOrElse((Seq.empty[String], Seq.empty[ActionSource], Seq.empty[HealthActionType], None))
+  }
+
+  def scribeVerdict(
+    visibilityResult: VisibilityResult,
+    safetyLevel: SafetyLevel,
+    vfLibType: VFLibType,
+    viewerId: Option[Long] = None
+  ): Unit = {
+    publisherOpt.foreach { publisher =>
+      toEntityId(visibilityResult.contentId).foreach { entityId =>
+        visibilityResult.verdict.toHealthActionTypeThrift.foreach { healthActionType =>
+          val (actioningRules, actionSources, healthActionTypes, misinfoPolicyCategory) =
+            getLogEntryData(
+              actingRule = visibilityResult.actingRule,
+              secondaryActingRules = visibilityResult.secondaryActingRules,
+              verdict = visibilityResult.verdict,
+              secondaryVerdicts = visibilityResult.secondaryVerdicts,
+              resolvedFeatureMap = visibilityResult.resolvedFeatureMap
+            )
+
+          if (actioningRules.nonEmpty) {
+            log(
+              VFVerdictLogEntry(
+                entityId = entityId,
+                viewerId = viewerId,
+                timestampMsec = System.currentTimeMillis(),
+                vfLibType = vfLibType,
+                healthActionType = healthActionType,
+                safetyLevel = safetyLevel,
+                actioningRules = actioningRules,
+                actionSources = actionSources,
+                healthActionTypes = healthActionTypes,
+                misinfoPolicyCategory =
+                  fleetInterstitialToMisinfoPolicyCategory(misinfoPolicyCategory)
+              ),
+              publisher
+            )
+          }
+        }
+      }
+    }
+  }
+
+  def fleetInterstitialToMisinfoPolicyCategory(
+    fleetInterstitialOption: Option[FleetInterstitial]
+  ): Option[MisinfoPolicyCategory] = {
+    fleetInterstitialOption.map {
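+      // Any FleetInterstitial variant without an explicit case below maps to Unknown.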
+ case FleetInterstitial.Generic => + MisinfoPolicyCategory.Generic + case FleetInterstitial.Samm => + MisinfoPolicyCategory.Samm + case FleetInterstitial.CivicIntegrity => + MisinfoPolicyCategory.CivicIntegrity + case _ => MisinfoPolicyCategory.Unknown + } + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResult.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResult.scala new file mode 100644 index 000000000..bdf8764eb --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResult.scala @@ -0,0 +1,112 @@ +package com.twitter.visibility.builder + +import com.twitter.spam.rtf.thriftscala.SafetyResult +import com.twitter.visibility.common.actions.converter.scala.DropReasonConverter +import com.twitter.visibility.rules.ComposableActions._ +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.rules._ +import com.twitter.visibility.{thriftscala => t} + +case class VisibilityResult( + contentId: ContentId, + featureMap: FeatureMap = FeatureMap.empty, + ruleResultMap: Map[Rule, RuleResult] = Map.empty, + verdict: Action = Allow, + finished: Boolean = false, + actingRule: Option[Rule] = None, + secondaryActingRules: Seq[Rule] = Seq(), + secondaryVerdicts: Seq[Action] = Seq(), + resolvedFeatureMap: Map[Feature[_], Any] = Map.empty) { + + def getSafetyResult: SafetyResult = + verdict match { + case InterstitialLimitedEngagements(reason: Reason, _, _, _) + if PublicInterest.Reasons + .contains(reason) => + SafetyResult( + Some(PublicInterest.ReasonToSafetyResultReason(reason)), + verdict.toActionThrift() + ) + case ComposableActionsWithInterstitialLimitedEngagements(tweetInterstitial) + if PublicInterest.Reasons.contains(tweetInterstitial.reason) => + SafetyResult( + Some(PublicInterest.ReasonToSafetyResultReason(tweetInterstitial.reason)), + verdict.toActionThrift() + ) + case FreedomOfSpeechNotReachReason(appealableReason) => + SafetyResult( + Some(FreedomOfSpeechNotReach.reasonToSafetyResultReason(appealableReason)), + verdict.toActionThrift() + ) + case _ => SafetyResult(None, verdict.toActionThrift()) + } + + def getUserVisibilityResult: Option[t.UserVisibilityResult] = + (verdict match { + case Drop(reason, _) => + Some( + t.UserAction.Drop(t.Drop(Reason.toDropReason(reason).map(DropReasonConverter.toThrift)))) + case _ => None + }).map(userAction => t.UserVisibilityResult(Some(userAction))) +} + +object VisibilityResult { + class Builder { + var featureMap: FeatureMap = FeatureMap.empty + var ruleResultMap: Map[Rule, RuleResult] = Map.empty + var verdict: Action = Allow + var finished: Boolean = false + var actingRule: Option[Rule] = None + var secondaryActingRules: Seq[Rule] = Seq() + var secondaryVerdicts: Seq[Action] = Seq() + var resolvedFeatureMap: Map[Feature[_], Any] = Map.empty + + def withFeatureMap(featureMapBld: FeatureMap) = { + featureMap = featureMapBld + this + } + + def withRuleResultMap(ruleResultMapBld: Map[Rule, RuleResult]) = { + ruleResultMap = ruleResultMapBld + this + } + + def withVerdict(verdictBld: Action) = { + verdict = verdictBld + this + } + + def withFinished(finishedBld: Boolean) = { + finished = finishedBld + this + } + + def withActingRule(actingRuleBld: Option[Rule]) = { + actingRule = actingRuleBld + this + } + + def withSecondaryActingRules(secondaryActingRulesBld: Seq[Rule]) = { + secondaryActingRules = secondaryActingRulesBld + 
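// each withX method mutates this builder and returns it for chaining + 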
this + } + + def withSecondaryVerdicts(secondaryVerdictsBld: Seq[Action]) = { + secondaryVerdicts = secondaryVerdictsBld + this + } + + def build(contentId: ContentId) = VisibilityResult( + contentId, + featureMap, + ruleResultMap, + verdict, + finished, + actingRule, + secondaryActingRules, + secondaryVerdicts, + resolvedFeatureMap) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResultBuilder.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResultBuilder.scala new file mode 100644 index 000000000..83731eb88 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/VisibilityResultBuilder.scala @@ -0,0 +1,114 @@ +package com.twitter.visibility.builder + +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.FailClosedException +import com.twitter.visibility.rules.FeaturesFailedException +import com.twitter.visibility.rules.MissingFeaturesException +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.RuleFailedException +import com.twitter.visibility.rules.RuleResult +import com.twitter.visibility.rules.State.FeatureFailed +import com.twitter.visibility.rules.State.MissingFeature +import com.twitter.visibility.rules.State.RuleFailed + +class VisibilityResultBuilder( + val contentId: ContentId, + val featureMap: FeatureMap = FeatureMap.empty, + private var ruleResultMap: Map[Rule, RuleResult] = Map.empty) { + private var mapBuilder = Map.newBuilder[Rule, RuleResult] + mapBuilder ++= ruleResultMap + var verdict: Action = Allow + var finished: Boolean = false + var features: FeatureMap = featureMap + var actingRule: Option[Rule] = None + var secondaryVerdicts: Seq[Action] = Seq() + var secondaryActingRules: Seq[Rule] = Seq() + var resolvedFeatureMap: Map[Feature[_], Any] = Map.empty + + def ruleResults: Map[Rule, RuleResult] = mapBuilder.result() + + def withFeatureMap(featureMap: FeatureMap): VisibilityResultBuilder = { + this.features = featureMap + this + } + + def withRuleResultMap(ruleResultMap: Map[Rule, RuleResult]): VisibilityResultBuilder = { + this.ruleResultMap = ruleResultMap + mapBuilder = Map.newBuilder[Rule, RuleResult] + mapBuilder ++= ruleResultMap + this + } + + def withRuleResult(rule: Rule, result: RuleResult): VisibilityResultBuilder = { + mapBuilder += ((rule, result)) + this + } + + def withVerdict(verdict: Action, ruleOpt: Option[Rule] = None): VisibilityResultBuilder = { + this.verdict = verdict + this.actingRule = ruleOpt + this + } + + def withSecondaryVerdict(verdict: Action, rule: Rule): VisibilityResultBuilder = { + this.secondaryVerdicts = this.secondaryVerdicts :+ verdict + this.secondaryActingRules = this.secondaryActingRules :+ rule + this + } + + def withFinished(finished: Boolean): VisibilityResultBuilder = { + this.finished = finished + this + } + + def withResolvedFeatureMap( + resolvedFeatureMap: Map[Feature[_], Any] + ): VisibilityResultBuilder = { + this.resolvedFeatureMap = resolvedFeatureMap + this + } + + def isVerdictComposable(): Boolean = this.verdict.isComposable + + def failClosedException(evaluationContext: EvaluationContext): Option[FailClosedException] = { + mapBuilder + .result().collect { + case (r: Rule, RuleResult(_, MissingFeature(mf))) + if 
r.shouldFailClosed(evaluationContext.params) => + Some(MissingFeaturesException(r.name, mf)) + case (r: Rule, RuleResult(_, FeatureFailed(ff))) + if r.shouldFailClosed(evaluationContext.params) => + Some(FeaturesFailedException(r.name, ff)) + case (r: Rule, RuleResult(_, RuleFailed(t))) + if r.shouldFailClosed(evaluationContext.params) => + Some(RuleFailedException(r.name, t)) + }.toList.foldLeft(None: Option[FailClosedException]) { (acc, arg) => + (acc, arg) match { + case (None, Some(_)) => arg + case (Some(FeaturesFailedException(_, _)), Some(MissingFeaturesException(_, _))) => arg + case (Some(RuleFailedException(_, _)), Some(MissingFeaturesException(_, _))) => arg + case (Some(RuleFailedException(_, _)), Some(FeaturesFailedException(_, _))) => arg + case _ => acc + } + } + } + + def build: VisibilityResult = { + VisibilityResult( + contentId = contentId, + featureMap = features, + ruleResultMap = mapBuilder.result(), + verdict = verdict, + finished = finished, + actingRule = actingRule, + secondaryActingRules = secondaryActingRules, + secondaryVerdicts = secondaryVerdicts, + resolvedFeatureMap = resolvedFeatureMap + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/BUILD new file mode 100644 index 000000000..5d44ccfd1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/BUILD @@ -0,0 +1,32 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala", + "communities/thrift/src/main/thrift/com/twitter/communities/moderation:thrift-scala", + "escherbird/src/thrift/com/twitter/escherbird/softintervention:softintervention_thrift-scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/escherbird/common:common-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "tweetypie/src/scala/com/twitter/tweetypie/additionalfields", + "twitter-context/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/blender", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/MutedKeywordFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/MutedKeywordFeatures.scala new file mode 100644 index 000000000..eb0a21663 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/common/MutedKeywordFeatures.scala @@ -0,0 +1,228 @@ +package com.twitter.visibility.builder.common + +import 
com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.MuteOption +import com.twitter.gizmoduck.thriftscala.MuteSurface +import com.twitter.gizmoduck.thriftscala.{MutedKeyword => GdMutedKeyword} +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common._ +import com.twitter.visibility.features._ +import com.twitter.visibility.models.{MutedKeyword => VfMutedKeyword} +import java.util.Locale + +class MutedKeywordFeatures( + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + keywordMatcher: KeywordMatcher.Matcher = KeywordMatcher.TestMatcher, + statsReceiver: StatsReceiver, + enableFollowCheckInMutedKeyword: Gate[Unit] = Gate.False) { + + private[this] val scopedStatsReceiver: StatsReceiver = + statsReceiver.scope("muted_keyword_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerMutesKeywordInTweetForHomeTimeline = + scopedStatsReceiver.scope(ViewerMutesKeywordInTweetForHomeTimeline.name).counter("requests") + private[this] val viewerMutesKeywordInTweetForTweetReplies = + scopedStatsReceiver.scope(ViewerMutesKeywordInTweetForTweetReplies.name).counter("requests") + private[this] val viewerMutesKeywordInTweetForNotifications = + scopedStatsReceiver.scope(ViewerMutesKeywordInTweetForNotifications.name).counter("requests") + private[this] val excludeFollowingForMutedKeywordsRequests = + scopedStatsReceiver.scope("exclude_following").counter("requests") + private[this] val viewerMutesKeywordInTweetForAllSurfaces = + scopedStatsReceiver.scope(ViewerMutesKeywordInTweetForAllSurfaces.name).counter("requests") + + def forTweet( + tweet: Tweet, + viewerId: Option[Long], + authorId: Long + ): FeatureMapBuilder => FeatureMapBuilder = { featureMapBuilder => + requests.incr() + viewerMutesKeywordInTweetForHomeTimeline.incr() + viewerMutesKeywordInTweetForTweetReplies.incr() + viewerMutesKeywordInTweetForNotifications.incr() + viewerMutesKeywordInTweetForAllSurfaces.incr() + + val keywordsBySurface = allMutedKeywords(viewerId) + + val keywordsWithoutDefinedSurface = allMutedKeywordsWithoutDefinedSurface(viewerId) + + featureMapBuilder + .withFeature( + ViewerMutesKeywordInTweetForHomeTimeline, + tweetContainsMutedKeyword( + tweet, + keywordsBySurface, + MuteSurface.HomeTimeline, + viewerId, + authorId + ) + ) + .withFeature( + ViewerMutesKeywordInTweetForTweetReplies, + tweetContainsMutedKeyword( + tweet, + keywordsBySurface, + MuteSurface.TweetReplies, + viewerId, + authorId + ) + ) + .withFeature( + ViewerMutesKeywordInTweetForNotifications, + tweetContainsMutedKeyword( + tweet, + keywordsBySurface, + MuteSurface.Notifications, + viewerId, + authorId + ) + ) + .withFeature( + ViewerMutesKeywordInTweetForAllSurfaces, + tweetContainsMutedKeywordWithoutDefinedSurface( + tweet, + keywordsWithoutDefinedSurface, + viewerId, + authorId + ) + ) + } + + def allMutedKeywords(viewerId: Option[Long]): Stitch[Map[MuteSurface, Seq[GdMutedKeyword]]] = + viewerId + .map { id => userSource.getAllMutedKeywords(id) }.getOrElse(Stitch.value(Map.empty)) + + def allMutedKeywordsWithoutDefinedSurface(viewerId: Option[Long]): Stitch[Seq[GdMutedKeyword]] = + viewerId + .map { id => userSource.getAllMutedKeywordsWithoutDefinedSurface(id) }.getOrElse( + Stitch.value(Seq.empty)) + + private def mutingKeywordsText( + mutedKeywords: Seq[GdMutedKeyword], + muteSurface: 
MuteSurface,
+    viewerIdOpt: Option[Long],
+    authorId: Long
+  ): Stitch[Option[String]] = {
+    if (muteSurface == MuteSurface.HomeTimeline && mutedKeywords.nonEmpty) {
+      Stitch.value(Some(mutedKeywords.map(_.keyword).mkString(",")))
+    } else {
+      mutedKeywords.partition(kw =>
+        kw.muteOptions.contains(MuteOption.ExcludeFollowingAccounts)) match {
+        case (_, mutedKeywordsFromAnyone) if mutedKeywordsFromAnyone.nonEmpty =>
+          Stitch.value(Some(mutedKeywordsFromAnyone.map(_.keyword).mkString(",")))
+        case (mutedKeywordsExcludeFollowing, _)
+            if mutedKeywordsExcludeFollowing.nonEmpty && enableFollowCheckInMutedKeyword() =>
+          excludeFollowingForMutedKeywordsRequests.incr()
+          viewerIdOpt match {
+            case Some(viewerId) =>
+              userRelationshipSource.follows(viewerId, authorId).map {
+                case true => None // viewer follows the author, so exclude-following mutes do not apply
+                case false => Some(mutedKeywordsExcludeFollowing.map(_.keyword).mkString(","))
+              }
+            case _ => Stitch.None
+          }
+        case (_, _) => Stitch.None
+      }
+    }
+  }
+
+  private def mutingKeywordsTextWithoutDefinedSurface(
+    mutedKeywords: Seq[GdMutedKeyword],
+    viewerIdOpt: Option[Long],
+    authorId: Long
+  ): Stitch[Option[String]] = {
+    mutedKeywords.partition(kw =>
+      kw.muteOptions.contains(MuteOption.ExcludeFollowingAccounts)) match {
+      case (_, mutedKeywordsFromAnyone) if mutedKeywordsFromAnyone.nonEmpty =>
+        Stitch.value(Some(mutedKeywordsFromAnyone.map(_.keyword).mkString(",")))
+      case (mutedKeywordsExcludeFollowing, _)
+          if mutedKeywordsExcludeFollowing.nonEmpty && enableFollowCheckInMutedKeyword() =>
+        excludeFollowingForMutedKeywordsRequests.incr()
+        viewerIdOpt match {
+          case Some(viewerId) =>
+            userRelationshipSource.follows(viewerId, authorId).map {
+              case true => None // viewer follows the author, so the mute does not apply
+              case false => Some(mutedKeywordsExcludeFollowing.map(_.keyword).mkString(","))
+            }
+          case _ => Stitch.None
+        }
+      case (_, _) => Stitch.None
+    }
+  }
+
+  def tweetContainsMutedKeyword(
+    tweet: Tweet,
+    mutedKeywordMap: Stitch[Map[MuteSurface, Seq[GdMutedKeyword]]],
+    muteSurface: MuteSurface,
+    viewerIdOpt: Option[Long],
+    authorId: Long
+  ): Stitch[VfMutedKeyword] = {
+    mutedKeywordMap.flatMap { keywordMap =>
+      if (keywordMap.isEmpty) { + 
Stitch.value(VfMutedKeyword(None)) + } else { + val mutedKeywords = keywordMap.getOrElse(muteSurface, Nil) + val matchTweetFn: KeywordMatcher.MatchTweet = keywordMatcher(mutedKeywords) + + val locale = spaceLanguageOpt.map(l => Locale.forLanguageTag(l)) + matchTweetFn(locale, spaceTitle).flatMap { results => + if (results.nonEmpty) { + Stitch.value(Some(results.map(_.keyword).mkString(","))).map(VfMutedKeyword) + } else { + Stitch.None.map(VfMutedKeyword) + } + } + } + } + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/BUILD new file mode 100644 index 000000000..3c769bb73 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/BUILD @@ -0,0 +1,23 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "src/thrift/com/twitter/convosvc:convosvc-scala", + "src/thrift/com/twitter/convosvc/internal:internal-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "stitch/stitch-core", + "stitch/stitch-core/src/main/scala/com/twitter/stitch", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/dm_sources", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/safety_label_store:safety-label-store-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmConversationFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmConversationFeatures.scala new file mode 100644 index 000000000..ad21e40ad --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmConversationFeatures.scala @@ -0,0 +1,196 @@ +package com.twitter.visibility.builder.dms + +import com.twitter.convosvc.thriftscala.ConversationQuery +import com.twitter.convosvc.thriftscala.ConversationQueryOptions +import com.twitter.convosvc.thriftscala.ConversationType +import com.twitter.convosvc.thriftscala.TimelineLookupState +import com.twitter.stitch.NotFound +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.DmConversationId +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.dm_sources.DmConversationSource +import com.twitter.visibility.features._ + +case class InvalidDmConversationFeatureException(message: String) extends Exception(message) + +class DmConversationFeatures( + dmConversationSource: DmConversationSource, + authorFeatures: AuthorFeatures) { + + def forDmConversationId( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): FeatureMapBuilder => FeatureMapBuilder = + _.withFeature( + DmConversationIsOneToOneConversation, + dmConversationIsOneToOneConversation(dmConversationId, viewerIdOpt)) + .withFeature( + DmConversationHasEmptyTimeline, + dmConversationHasEmptyTimeline(dmConversationId, viewerIdOpt)) + .withFeature( + 
DmConversationHasValidLastReadableEventId, + dmConversationHasValidLastReadableEventId(dmConversationId, viewerIdOpt)) + .withFeature( + DmConversationInfoExists, + dmConversationInfoExists(dmConversationId, viewerIdOpt)) + .withFeature( + DmConversationTimelineExists, + dmConversationTimelineExists(dmConversationId, viewerIdOpt)) + .withFeature( + AuthorIsSuspended, + dmConversationHasSuspendedParticipant(dmConversationId, viewerIdOpt)) + .withFeature( + AuthorIsDeactivated, + dmConversationHasDeactivatedParticipant(dmConversationId, viewerIdOpt)) + .withFeature( + AuthorIsErased, + dmConversationHasErasedParticipant(dmConversationId, viewerIdOpt)) + .withFeature( + ViewerIsDmConversationParticipant, + viewerIsDmConversationParticipant(dmConversationId, viewerIdOpt)) + + def dmConversationIsOneToOneConversation( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + viewerIdOpt match { + case Some(viewerId) => + dmConversationSource.getConversationType(dmConversationId, viewerId).flatMap { + case Some(ConversationType.OneToOneDm | ConversationType.SecretOneToOneDm) => + Stitch.True + case None => + Stitch.exception(InvalidDmConversationFeatureException("Conversation type not found")) + case _ => Stitch.False + } + case _ => Stitch.exception(InvalidDmConversationFeatureException("Viewer id missing")) + } + + private[dms] def dmConversationHasEmptyTimeline( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + dmConversationSource + .getConversationTimelineEntries( + dmConversationId, + ConversationQuery( + conversationId = Some(dmConversationId), + options = Some( + ConversationQueryOptions( + perspectivalUserId = viewerIdOpt, + hydrateEvents = Some(false), + supportsReactions = Some(true) + ) + ), + maxCount = 10 + ) + ).map(_.forall(entries => entries.isEmpty)) + + private[dms] def dmConversationHasValidLastReadableEventId( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + viewerIdOpt match { + case Some(viewerId) => + dmConversationSource + .getConversationLastReadableEventId(dmConversationId, viewerId).map(_.exists(id => + id > 0L)) + case _ => Stitch.exception(InvalidDmConversationFeatureException("Viewer id missing")) + } + + private[dms] def dmConversationInfoExists( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + viewerIdOpt match { + case Some(viewerId) => + dmConversationSource + .getDmConversationInfo(dmConversationId, viewerId).map(_.isDefined) + case _ => Stitch.exception(InvalidDmConversationFeatureException("Viewer id missing")) + } + + private[dms] def dmConversationTimelineExists( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + dmConversationSource + .getConversationTimelineState( + dmConversationId, + ConversationQuery( + conversationId = Some(dmConversationId), + options = Some( + ConversationQueryOptions( + perspectivalUserId = viewerIdOpt, + hydrateEvents = Some(false), + supportsReactions = Some(true) + ) + ), + maxCount = 1 + ) + ).map { + case Some(TimelineLookupState.NotFound) | None => false + case _ => true + } + + private[dms] def anyConversationParticipantMatchesCondition( + condition: UserId => Stitch[Boolean], + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + viewerIdOpt match { + case Some(viewerId) => + dmConversationSource + .getConversationParticipantIds(dmConversationId, viewerId).flatMap { + case 
Some(participants) => + Stitch + .collect(participants.map(condition)).map(_.contains(true)).rescue { + case NotFound => + Stitch.exception(InvalidDmConversationFeatureException("User not found")) + } + case _ => Stitch.False + } + case _ => Stitch.exception(InvalidDmConversationFeatureException("Viewer id missing")) + } + + def dmConversationHasSuspendedParticipant( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + anyConversationParticipantMatchesCondition( + participant => authorFeatures.authorIsSuspended(participant), + dmConversationId, + viewerIdOpt) + + def dmConversationHasDeactivatedParticipant( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + anyConversationParticipantMatchesCondition( + participant => authorFeatures.authorIsDeactivated(participant), + dmConversationId, + viewerIdOpt) + + def dmConversationHasErasedParticipant( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + anyConversationParticipantMatchesCondition( + participant => authorFeatures.authorIsErased(participant), + dmConversationId, + viewerIdOpt) + + def viewerIsDmConversationParticipant( + dmConversationId: DmConversationId, + viewerIdOpt: Option[UserId] + ): Stitch[Boolean] = + viewerIdOpt match { + case Some(viewerId) => + dmConversationSource + .getConversationParticipantIds(dmConversationId, viewerId).map { + case Some(participants) => participants.contains(viewerId) + case _ => false + } + case _ => Stitch.exception(InvalidDmConversationFeatureException("Viewer id missing")) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmEventFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmEventFeatures.scala new file mode 100644 index 000000000..797a52268 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/dms/DmEventFeatures.scala @@ -0,0 +1,341 @@ +package com.twitter.visibility.builder.dms + +import com.twitter.convosvc.thriftscala.Event +import com.twitter.convosvc.thriftscala.StoredDelete +import com.twitter.convosvc.thriftscala.StoredPerspectivalMessageInfo +import com.twitter.convosvc.thriftscala.PerspectivalSpamState +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.DmEventId +import com.twitter.visibility.common.dm_sources.DmEventSource +import com.twitter.visibility.common.UserId +import com.twitter.convosvc.thriftscala.EventType +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.NotFound +import com.twitter.visibility.common.dm_sources.DmConversationSource +import com.twitter.visibility.features._ + +case class InvalidDmEventFeatureException(message: String) extends Exception(message) + +class DmEventFeatures( + dmEventSource: DmEventSource, + dmConversationSource: DmConversationSource, + authorFeatures: AuthorFeatures, + dmConversationFeatures: DmConversationFeatures, + statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("dm_event_features") + private[this] val requests = scopedStatsReceiver.counter("requests") + + def forDmEventId( + dmEventId: DmEventId, + viewerId: UserId + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + val dmEventStitchRef: Stitch[Option[Event]] = + Stitch.ref(dmEventSource.getDmEvent(dmEventId, viewerId)) + + _.withFeature( + 
DmEventIsMessageCreateEvent, + isDmEventType(dmEventStitchRef, EventType.MessageCreate)) + .withFeature( + AuthorIsSuspended, + messageCreateEventHasInactiveInitiatingUser( + dmEventStitchRef, + initiatingUser => authorFeatures.authorIsSuspended(initiatingUser)) + ) + .withFeature( + AuthorIsDeactivated, + messageCreateEventHasInactiveInitiatingUser( + dmEventStitchRef, + initiatingUser => authorFeatures.authorIsDeactivated(initiatingUser)) + ) + .withFeature( + AuthorIsErased, + messageCreateEventHasInactiveInitiatingUser( + dmEventStitchRef, + initiatingUser => authorFeatures.authorIsErased(initiatingUser)) + ) + .withFeature( + DmEventOccurredBeforeLastClearedEvent, + dmEventOccurredBeforeLastClearedEvent(dmEventStitchRef, dmEventId, viewerId) + ) + .withFeature( + DmEventOccurredBeforeJoinConversationEvent, + dmEventOccurredBeforeJoinConversationEvent(dmEventStitchRef, dmEventId, viewerId) + ) + .withFeature( + ViewerIsDmConversationParticipant, + dmEventViewerIsDmConversationParticipant(dmEventStitchRef, viewerId) + ) + .withFeature( + DmEventIsDeleted, + dmEventIsDeleted(dmEventStitchRef, dmEventId) + ) + .withFeature( + DmEventIsHidden, + dmEventIsHidden(dmEventStitchRef, dmEventId) + ) + .withFeature( + ViewerIsDmEventInitiatingUser, + viewerIsDmEventInitiatingUser(dmEventStitchRef, viewerId) + ) + .withFeature( + DmEventInOneToOneConversationWithUnavailableUser, + dmEventInOneToOneConversationWithUnavailableUser(dmEventStitchRef, viewerId) + ) + .withFeature( + DmEventIsLastMessageReadUpdateEvent, + isDmEventType(dmEventStitchRef, EventType.LastMessageReadUpdate) + ) + .withFeature( + DmEventIsJoinConversationEvent, + isDmEventType(dmEventStitchRef, EventType.JoinConversation) + ) + .withFeature( + DmEventIsWelcomeMessageCreateEvent, + isDmEventType(dmEventStitchRef, EventType.WelcomeMessageCreate) + ) + .withFeature( + DmEventIsTrustConversationEvent, + isDmEventType(dmEventStitchRef, EventType.TrustConversation) + ) + .withFeature( + DmEventIsCsFeedbackSubmitted, + isDmEventType(dmEventStitchRef, EventType.CsFeedbackSubmitted) + ) + .withFeature( + DmEventIsCsFeedbackDismissed, + isDmEventType(dmEventStitchRef, EventType.CsFeedbackDismissed) + ) + .withFeature( + DmEventIsConversationCreateEvent, + isDmEventType(dmEventStitchRef, EventType.ConversationCreate) + ) + .withFeature( + DmEventInOneToOneConversation, + dmEventInOneToOneConversation(dmEventStitchRef, viewerId) + ) + .withFeature( + DmEventIsPerspectivalJoinConversationEvent, + dmEventIsPerspectivalJoinConversationEvent(dmEventStitchRef, dmEventId, viewerId)) + + } + + private def isDmEventType( + dmEventOptStitch: Stitch[Option[Event]], + eventType: EventType + ): Stitch[Boolean] = + dmEventSource.getEventType(dmEventOptStitch).flatMap { + case Some(_: eventType.type) => + Stitch.True + case None => + Stitch.exception(InvalidDmEventFeatureException(s"$eventType event type not found")) + case _ => + Stitch.False + } + + private def dmEventIsPerspectivalJoinConversationEvent( + dmEventOptStitch: Stitch[Option[Event]], + dmEventId: DmEventId, + viewerId: UserId + ): Stitch[Boolean] = + Stitch + .join( + dmEventSource.getEventType(dmEventOptStitch), + dmEventSource.getConversationId(dmEventOptStitch)).flatMap { + case (Some(EventType.JoinConversation), conversationIdOpt) => + conversationIdOpt match { + case Some(conversationId) => + dmConversationSource + .getParticipantJoinConversationEventId(conversationId, viewerId, viewerId) + .flatMap { + case Some(joinConversationEventId) => + Stitch.value(joinConversationEventId == 
dmEventId) + case _ => Stitch.False + } + case _ => + Stitch.exception(InvalidDmEventFeatureException("Conversation id not found")) + } + case (None, _) => + Stitch.exception(InvalidDmEventFeatureException("Event type not found")) + case _ => Stitch.False + } + + private def messageCreateEventHasInactiveInitiatingUser( + dmEventOptStitch: Stitch[Option[Event]], + condition: UserId => Stitch[Boolean], + ): Stitch[Boolean] = + Stitch + .join( + dmEventSource.getEventType(dmEventOptStitch), + dmEventSource.getInitiatingUserId(dmEventOptStitch)).flatMap { + case (Some(EventType.MessageCreate), Some(userId)) => + condition(userId).rescue { + case NotFound => + Stitch.exception(InvalidDmEventFeatureException("initiating user not found")) + } + case (None, _) => + Stitch.exception(InvalidDmEventFeatureException("DmEvent type is missing")) + case (Some(EventType.MessageCreate), _) => + Stitch.exception(InvalidDmEventFeatureException("initiating user id is missing")) + case _ => Stitch.False + } + + private def dmEventOccurredBeforeLastClearedEvent( + dmEventOptStitch: Stitch[Option[Event]], + dmEventId: DmEventId, + viewerId: UserId + ): Stitch[Boolean] = { + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(convoId) => + val lastClearedEventIdStitch = + dmConversationSource.getParticipantLastClearedEventId(convoId, viewerId, viewerId) + lastClearedEventIdStitch.flatMap { + case Some(lastClearedEventId) => Stitch(dmEventId <= lastClearedEventId) + case _ => + Stitch.False + } + case _ => Stitch.False + } + } + + private def dmEventOccurredBeforeJoinConversationEvent( + dmEventOptStitch: Stitch[Option[Event]], + dmEventId: DmEventId, + viewerId: UserId + ): Stitch[Boolean] = { + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(convoId) => + val joinConversationEventIdStitch = + dmConversationSource + .getParticipantJoinConversationEventId(convoId, viewerId, viewerId) + joinConversationEventIdStitch.flatMap { + case Some(joinConversationEventId) => Stitch(dmEventId < joinConversationEventId) + case _ => Stitch.False + } + case _ => Stitch.False + } + } + + private def dmEventViewerIsDmConversationParticipant( + dmEventOptStitch: Stitch[Option[Event]], + viewerId: UserId + ): Stitch[Boolean] = { + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(convoId) => + dmConversationFeatures.viewerIsDmConversationParticipant(convoId, Some(viewerId)) + case _ => Stitch.True + } + } + + private def dmEventIsDeleted( + dmEventOptStitch: Stitch[Option[Event]], + dmEventId: DmEventId + ): Stitch[Boolean] = + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(convoId) => + dmConversationSource + .getDeleteInfo(convoId, dmEventId).rescue { + case e: java.lang.IllegalArgumentException => + Stitch.exception(InvalidDmEventFeatureException("Invalid conversation id")) + }.flatMap { + case Some(StoredDelete(None)) => Stitch.True + case _ => Stitch.False + } + case _ => Stitch.False + } + + private def dmEventIsHidden( + dmEventOptStitch: Stitch[Option[Event]], + dmEventId: DmEventId + ): Stitch[Boolean] = + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(convoId) => + dmConversationSource + .getPerspectivalMessageInfo(convoId, dmEventId).rescue { + case e: java.lang.IllegalArgumentException => + Stitch.exception(InvalidDmEventFeatureException("Invalid conversation id")) + }.flatMap { + case Some(StoredPerspectivalMessageInfo(Some(hidden), _)) if hidden => + Stitch.True + case 
Some(StoredPerspectivalMessageInfo(_, Some(spamState))) + if spamState == PerspectivalSpamState.Spam => + Stitch.True + case _ => Stitch.False + } + case _ => Stitch.False + } + + private def viewerIsDmEventInitiatingUser( + dmEventOptStitch: Stitch[Option[Event]], + viewerId: UserId + ): Stitch[Boolean] = + Stitch + .join( + dmEventSource.getEventType(dmEventOptStitch), + dmEventSource.getInitiatingUserId(dmEventOptStitch)).flatMap { + case ( + Some( + EventType.TrustConversation | EventType.CsFeedbackSubmitted | + EventType.CsFeedbackDismissed | EventType.WelcomeMessageCreate | + EventType.JoinConversation), + Some(userId)) => + Stitch(viewerId == userId) + case ( + Some( + EventType.TrustConversation | EventType.CsFeedbackSubmitted | + EventType.CsFeedbackDismissed | EventType.WelcomeMessageCreate | + EventType.JoinConversation), + None) => + Stitch.exception(InvalidDmEventFeatureException("Initiating user id is missing")) + case (None, _) => + Stitch.exception(InvalidDmEventFeatureException("DmEvent type is missing")) + case _ => Stitch.True + } + + private def dmEventInOneToOneConversationWithUnavailableUser( + dmEventOptStitch: Stitch[Option[Event]], + viewerId: UserId + ): Stitch[Boolean] = + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(conversationId) => + dmConversationFeatures + .dmConversationIsOneToOneConversation(conversationId, Some(viewerId)).flatMap { + isOneToOne => + if (isOneToOne) { + Stitch + .join( + dmConversationFeatures + .dmConversationHasSuspendedParticipant(conversationId, Some(viewerId)), + dmConversationFeatures + .dmConversationHasDeactivatedParticipant(conversationId, Some(viewerId)), + dmConversationFeatures + .dmConversationHasErasedParticipant(conversationId, Some(viewerId)) + ).flatMap { + case ( + convoParticipantIsSuspended, + convoParticipantIsDeactivated, + convoParticipantIsErased) => + Stitch.value( + convoParticipantIsSuspended || convoParticipantIsDeactivated || convoParticipantIsErased) + } + } else { + Stitch.False + } + } + case _ => Stitch.False + } + + private def dmEventInOneToOneConversation( + dmEventOptStitch: Stitch[Option[Event]], + viewerId: UserId + ): Stitch[Boolean] = + dmEventSource.getConversationId(dmEventOptStitch).flatMap { + case Some(conversationId) => + dmConversationFeatures + .dmConversationIsOneToOneConversation(conversationId, Some(viewerId)) + case _ => Stitch.False + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/BUILD new file mode 100644 index 000000000..d74c9c5d0 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/BUILD @@ -0,0 +1,31 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "mediaservices/commons/src/main/thrift:thrift-scala", + "mediaservices/media-util/src/main/scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:service-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "tweetypie/src/scala/com/twitter/tweetypie/additionalfields", + "twitter-context/src/main/scala", 
+ "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/safety_label_store:safety-label-store-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaFeatures.scala new file mode 100644 index 000000000..e20429846 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaFeatures.scala @@ -0,0 +1,90 @@ +package com.twitter.visibility.builder.media + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.MediaSafetyLabelMapSource +import com.twitter.visibility.features.MediaSafetyLabels +import com.twitter.visibility.models.MediaSafetyLabel +import com.twitter.visibility.models.MediaSafetyLabelType +import com.twitter.visibility.models.SafetyLabel + +class MediaFeatures( + mediaSafetyLabelMap: StratoMediaLabelMaps, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("media_features") + + private[this] val requests = + scopedStatsReceiver + .counter("requests") + + private[this] val mediaSafetyLabelsStats = + scopedStatsReceiver + .scope(MediaSafetyLabels.name) + .counter("requests") + + private[this] val nonEmptyMediaStats = scopedStatsReceiver.scope("non_empty_media") + private[this] val nonEmptyMediaRequests = nonEmptyMediaStats.counter("requests") + private[this] val nonEmptyMediaKeysCount = nonEmptyMediaStats.counter("keys") + private[this] val nonEmptyMediaKeysLength = nonEmptyMediaStats.stat("keys_length") + + def forMediaKeys( + mediaKeys: Seq[GenericMediaKey], + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + nonEmptyMediaKeysCount.incr(mediaKeys.size) + mediaSafetyLabelsStats.incr() + + if (mediaKeys.nonEmpty) { + nonEmptyMediaRequests.incr() + nonEmptyMediaKeysLength.add(mediaKeys.size) + } + + _.withFeature(MediaSafetyLabels, mediaSafetyLabelMap.forGenericMediaKeys(mediaKeys)) + } + + def forGenericMediaKey( + genericMediaKey: GenericMediaKey + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + nonEmptyMediaKeysCount.incr() + mediaSafetyLabelsStats.incr() + nonEmptyMediaRequests.incr() + nonEmptyMediaKeysLength.add(1L) + + _.withFeature(MediaSafetyLabels, mediaSafetyLabelMap.forGenericMediaKey(genericMediaKey)) + } +} + +class StratoMediaLabelMaps(source: MediaSafetyLabelMapSource) { + + def forGenericMediaKeys( + mediaKeys: Seq[GenericMediaKey], + ): Stitch[Seq[MediaSafetyLabel]] = { + Stitch + .collect( + mediaKeys + .map(getFilteredSafetyLabels) + ).map(_.flatten) + } + + def forGenericMediaKey( + genericMediaKey: GenericMediaKey + ): Stitch[Seq[MediaSafetyLabel]] = { + getFilteredSafetyLabels(genericMediaKey) + } + + private def getFilteredSafetyLabels( + genericMediaKey: GenericMediaKey, + ): Stitch[Seq[MediaSafetyLabel]] = + source + 
.fetch(genericMediaKey).map(_.flatMap(_.labels.map { stratoSafetyLabelMap => + stratoSafetyLabelMap + .map(label => + MediaSafetyLabel( + MediaSafetyLabelType.fromThrift(label._1), + SafetyLabel.fromThrift(label._2))) + }).toSeq.flatten) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaMetadataFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaMetadataFeatures.scala new file mode 100644 index 000000000..97f35b27c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/media/MediaMetadataFeatures.scala @@ -0,0 +1,79 @@ +package com.twitter.visibility.builder.media + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.MediaMetadataSource +import com.twitter.visibility.features.HasDmcaMediaFeature +import com.twitter.visibility.features.MediaGeoRestrictionsAllowList +import com.twitter.visibility.features.MediaGeoRestrictionsDenyList +import com.twitter.visibility.features.AuthorId + +class MediaMetadataFeatures( + mediaMetadataSource: MediaMetadataSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("media_metadata_features") + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val hasDmcaMedia = + scopedStatsReceiver.scope(HasDmcaMediaFeature.name).counter("requests") + private[this] val mediaGeoAllowList = + scopedStatsReceiver.scope(MediaGeoRestrictionsAllowList.name).counter("requests") + private[this] val mediaGeoDenyList = + scopedStatsReceiver.scope(MediaGeoRestrictionsDenyList.name).counter("requests") + private[this] val uploaderId = + scopedStatsReceiver.scope(AuthorId.name).counter("requests") + + def forGenericMediaKey( + genericMediaKey: GenericMediaKey + ): FeatureMapBuilder => FeatureMapBuilder = { featureMapBuilder => + requests.incr() + + featureMapBuilder.withFeature( + HasDmcaMediaFeature, + mediaIsDmca(genericMediaKey) + ) + + featureMapBuilder.withFeature( + MediaGeoRestrictionsAllowList, + geoRestrictionsAllowList(genericMediaKey) + ) + + featureMapBuilder.withFeature( + MediaGeoRestrictionsDenyList, + geoRestrictionsDenyList(genericMediaKey) + ) + + featureMapBuilder.withFeature( + AuthorId, + mediaUploaderId(genericMediaKey) + ) + } + + private def mediaIsDmca(genericMediaKey: GenericMediaKey) = { + hasDmcaMedia.incr() + mediaMetadataSource.getMediaIsDmca(genericMediaKey) + } + + private def geoRestrictionsAllowList(genericMediaKey: GenericMediaKey) = { + mediaGeoAllowList.incr() + mediaMetadataSource.getGeoRestrictionsAllowList(genericMediaKey).map { allowListOpt => + allowListOpt.getOrElse(Nil) + } + } + + private def geoRestrictionsDenyList(genericMediaKey: GenericMediaKey) = { + mediaGeoDenyList.incr() + mediaMetadataSource.getGeoRestrictionsDenyList(genericMediaKey).map { denyListOpt => + denyListOpt.getOrElse(Nil) + } + } + + private def mediaUploaderId(genericMediaKey: GenericMediaKey) = { + uploaderId.incr() + mediaMetadataSource.getMediaUploaderId(genericMediaKey).map { uploaderIdOpt => + uploaderIdOpt.map(Set(_)).getOrElse(Set.empty) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/BUILD new file mode 100644 index 000000000..ec82612ef --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/BUILD @@ -0,0 +1,25 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "stitch/stitch-core", + "twitter-context/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/common", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/safety_label_store:safety-label-store-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/SpaceFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/SpaceFeatures.scala new file mode 100644 index 000000000..6522f8e65 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/spaces/SpaceFeatures.scala @@ -0,0 +1,131 @@ +package com.twitter.visibility.builder.spaces + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.gizmoduck.thriftscala.MuteSurface +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.common.MutedKeywordFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.common.AudioSpaceSource +import com.twitter.visibility.common.SpaceId +import com.twitter.visibility.common.SpaceSafetyLabelMapSource +import com.twitter.visibility.common.UserId +import com.twitter.visibility.features._ +import com.twitter.visibility.models.{MutedKeyword => VfMutedKeyword} +import com.twitter.visibility.models.SafetyLabel +import com.twitter.visibility.models.SpaceSafetyLabel +import com.twitter.visibility.models.SpaceSafetyLabelType + +class SpaceFeatures( + spaceSafetyLabelMap: StratoSpaceLabelMaps, + authorFeatures: AuthorFeatures, + relationshipFeatures: RelationshipFeatures, + mutedKeywordFeatures: MutedKeywordFeatures, + audioSpaceSource: AudioSpaceSource) { + + def forSpaceAndAuthorIds( + spaceId: SpaceId, + viewerId: Option[UserId], + authorIds: Option[Seq[UserId]] + ): FeatureMapBuilder => FeatureMapBuilder = { + + _.withFeature(SpaceSafetyLabels, spaceSafetyLabelMap.forSpaceId(spaceId)) + .withFeature(AuthorId, getSpaceAuthors(spaceId, authorIds).map(_.toSet)) + .withFeature(AuthorUserLabels, allSpaceAuthorLabels(spaceId, authorIds)) + .withFeature(ViewerFollowsAuthor, viewerFollowsAnySpaceAuthor(spaceId, authorIds, viewerId)) + .withFeature(ViewerMutesAuthor, viewerMutesAnySpaceAuthor(spaceId, authorIds, viewerId)) + .withFeature(ViewerBlocksAuthor, viewerBlocksAnySpaceAuthor(spaceId, authorIds, viewerId)) + .withFeature(AuthorBlocksViewer, anySpaceAuthorBlocksViewer(spaceId, authorIds, viewerId)) + .withFeature( + 
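// Keyword mutes are evaluated against the space title only for the + // Notifications surface (titleContainsMutedKeyword passes MuteSurface.Notifications). + 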
ViewerMutesKeywordInSpaceTitleForNotifications, + titleContainsMutedKeyword( + audioSpaceSource.getSpaceTitle(spaceId), + audioSpaceSource.getSpaceLanguage(spaceId), + viewerId) + ) + } + + def titleContainsMutedKeyword( + titleOptStitch: Stitch[Option[String]], + languageOptStitch: Stitch[Option[String]], + viewerId: Option[UserId], + ): Stitch[VfMutedKeyword] = { + titleOptStitch.flatMap { + case None => Stitch.value(VfMutedKeyword(None)) + case Some(spaceTitle) => + languageOptStitch.flatMap { languageOpt => + mutedKeywordFeatures.spaceTitleContainsMutedKeyword( + spaceTitle, + languageOpt, + mutedKeywordFeatures.allMutedKeywords(viewerId), + MuteSurface.Notifications) + } + } + } + + def getSpaceAuthors( + spaceId: SpaceId, + authorIdsFromRequest: Option[Seq[UserId]] + ): Stitch[Seq[UserId]] = { + authorIdsFromRequest match { + case Some(authorIds) => Stitch.apply(authorIds) + case _ => audioSpaceSource.getAdminIds(spaceId) + } + } + + def allSpaceAuthorLabels( + spaceId: SpaceId, + authorIdsFromRequest: Option[Seq[UserId]] + ): Stitch[Seq[Label]] = { + getSpaceAuthors(spaceId, authorIdsFromRequest) + .flatMap(authorIds => + Stitch.collect(authorIds.map(authorId => authorFeatures.authorUserLabels(authorId)))).map( + _.flatten) + } + + def viewerMutesAnySpaceAuthor( + spaceId: SpaceId, + authorIdsFromRequest: Option[Seq[UserId]], + viewerId: Option[UserId] + ): Stitch[Boolean] = { + getSpaceAuthors(spaceId, authorIdsFromRequest) + .flatMap(authorIds => + Stitch.collect(authorIds.map(authorId => + relationshipFeatures.viewerMutesAuthor(authorId, viewerId)))).map(_.contains(true)) + } + + def anySpaceAuthorBlocksViewer( + spaceId: SpaceId, + authorIdsFromRequest: Option[Seq[UserId]], + viewerId: Option[UserId] + ): Stitch[Boolean] = { + getSpaceAuthors(spaceId, authorIdsFromRequest) + .flatMap(authorIds => + Stitch.collect(authorIds.map(authorId => + relationshipFeatures.authorBlocksViewer(authorId, viewerId)))).map(_.contains(true)) + } +} + +class StratoSpaceLabelMaps( + spaceSafetyLabelSource: SpaceSafetyLabelMapSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("space_features") + private[this] val spaceSafetyLabelsStats = + scopedStatsReceiver.scope(SpaceSafetyLabels.name).counter("requests") + + def forSpaceId( + spaceId: SpaceId, + ): Stitch[Seq[SpaceSafetyLabel]] = { + spaceSafetyLabelSource + .fetch(spaceId).map(_.flatMap(_.labels.map { stratoSafetyLabelMap => + stratoSafetyLabelMap + .map(label => + SpaceSafetyLabel( + SpaceSafetyLabelType.fromThrift(label._1), + SafetyLabel.fromThrift(label._2))) + }).toSeq.flatten).ensure(spaceSafetyLabelsStats.incr) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BUILD new file mode 100644 index 000000000..bde1871b4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BUILD @@ -0,0 +1,38 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala", + "communities/thrift/src/main/thrift/com/twitter/communities/moderation:thrift-scala", + "communities/thrift/src/main/thrift/com/twitter/communities/visibility:thrift-scala", + 
"escherbird/src/thrift/com/twitter/escherbird/softintervention:softintervention_thrift-scala", + "mediaservices/media-util/src/main/scala", + "notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias", + "notificationservice/common/src/main/scala/com/twitter/notificationservice/model/notification", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/escherbird/common:common-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + # "tweetypie/src/scala/com/twitter/tweetypie/additionalfields", + "twitter-context/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/blender", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/search", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BlenderContextFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BlenderContextFeatures.scala new file mode 100644 index 000000000..931574567 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/BlenderContextFeatures.scala @@ -0,0 +1,45 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.SearchCandidateCount +import com.twitter.visibility.features.SearchQueryHasUser +import com.twitter.visibility.features.SearchQuerySource +import com.twitter.visibility.features.SearchResultsPageNumber +import com.twitter.visibility.interfaces.common.blender.BlenderVFRequestContext + +@Deprecated +class BlenderContextFeatures( + statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("blender_context_features") + private[this] val requests = scopedStatsReceiver.counter("requests") + private[this] val searchResultsPageNumber = + scopedStatsReceiver.scope(SearchResultsPageNumber.name).counter("requests") + private[this] val searchCandidateCount = + scopedStatsReceiver.scope(SearchCandidateCount.name).counter("requests") + private[this] val searchQuerySource = + scopedStatsReceiver.scope(SearchQuerySource.name).counter("requests") + private[this] val searchQueryHasUser = + scopedStatsReceiver.scope(SearchQueryHasUser.name).counter("requests") + + def forBlenderContext( + blenderContext: BlenderVFRequestContext + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + searchResultsPageNumber.incr() + searchCandidateCount.incr() + searchQuerySource.incr() + searchQueryHasUser.incr() + + 
_.withConstantFeature(SearchResultsPageNumber, blenderContext.resultsPageNumber) + .withConstantFeature(SearchCandidateCount, blenderContext.candidateCount) + .withConstantFeature( + SearchQuerySource, + blenderContext.querySourceOption match { + case Some(querySource) => querySource + case _ => ThriftQuerySource.Unknown + }) + .withConstantFeature(SearchQueryHasUser, blenderContext.queryHasUser) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityNotificationFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityNotificationFeatures.scala new file mode 100644 index 000000000..a4588620b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityNotificationFeatures.scala @@ -0,0 +1,64 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.notificationservice.model.notification.ActivityNotification +import com.twitter.notificationservice.model.notification.MentionNotification +import com.twitter.notificationservice.model.notification.MentionQuoteNotification +import com.twitter.notificationservice.model.notification.Notification +import com.twitter.notificationservice.model.notification.QuoteTweetNotification +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.TweetSource +import com.twitter.visibility.features.NotificationIsOnCommunityTweet +import com.twitter.visibility.models.CommunityTweet + +object CommunityNotificationFeatures { + def ForNonCommunityTweetNotification: FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(NotificationIsOnCommunityTweet, false) + } +} + +class CommunityNotificationFeatures( + tweetSource: TweetSource, + enableCommunityTweetHydration: Gate[Long], + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("community_notification_features") + private[this] val requestsCounter = scopedStatsReceiver.counter("requests") + private[this] val hydrationsCounter = scopedStatsReceiver.counter("hydrations") + private[this] val notificationIsOnCommunityTweetCounter = + scopedStatsReceiver.scope(NotificationIsOnCommunityTweet.name).counter("true") + private[this] val notificationIsNotOnCommunityTweetCounter = + scopedStatsReceiver.scope(NotificationIsOnCommunityTweet.name).counter("false") + + def forNotification(notification: Notification): FeatureMapBuilder => FeatureMapBuilder = { + requestsCounter.incr() + val isCommunityTweetResult = getTweetIdOption(notification) match { + case Some(tweetId) if enableCommunityTweetHydration(notification.target) => + hydrationsCounter.incr() + tweetSource + .getTweet(tweetId) + .map { + case Some(tweet) if CommunityTweet(tweet).nonEmpty => + notificationIsOnCommunityTweetCounter.incr() + true + case _ => + notificationIsNotOnCommunityTweetCounter.incr() + false + } + case _ => Stitch.False + } + _.withFeature(NotificationIsOnCommunityTweet, isCommunityTweetResult) + } + + private[this] def getTweetIdOption(notification: Notification): Option[Long] = { + notification match { + case n: MentionNotification => Some(n.mentioningTweetId) + case n: MentionQuoteNotification => Some(n.mentioningTweetId) + case n: QuoteTweetNotification => Some(n.quotedTweetId) + case n: ActivityNotification[_] if n.visibilityTweets.contains(n.objectId) => Some(n.objectId) + case _ => None + } + } +} diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeatures.scala new file mode 100644 index 000000000..8a2c752f6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeatures.scala @@ -0,0 +1,70 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.CommunityTweetAuthorIsRemoved +import com.twitter.visibility.features.CommunityTweetCommunityNotFound +import com.twitter.visibility.features.CommunityTweetCommunityDeleted +import com.twitter.visibility.features.CommunityTweetCommunitySuspended +import com.twitter.visibility.features.CommunityTweetCommunityVisible +import com.twitter.visibility.features.CommunityTweetIsHidden +import com.twitter.visibility.features.TweetIsCommunityTweet +import com.twitter.visibility.features.ViewerIsCommunityAdmin +import com.twitter.visibility.features.ViewerIsCommunityMember +import com.twitter.visibility.features.ViewerIsCommunityModerator +import com.twitter.visibility.features.ViewerIsInternalCommunitiesAdmin +import com.twitter.visibility.models.CommunityTweet +import com.twitter.visibility.models.ViewerContext + +trait CommunityTweetFeatures { + + def forTweet( + tweet: Tweet, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder + + def forTweetOnly(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature( + TweetIsCommunityTweet, + CommunityTweet(tweet).isDefined + ) + } + + protected def forNonCommunityTweet(): FeatureMapBuilder => FeatureMapBuilder = { builder => + builder + .withConstantFeature( + TweetIsCommunityTweet, + false + ).withConstantFeature( + CommunityTweetCommunityNotFound, + false + ).withConstantFeature( + CommunityTweetCommunitySuspended, + false + ).withConstantFeature( + CommunityTweetCommunityDeleted, + false + ).withConstantFeature( + CommunityTweetCommunityVisible, + false + ).withConstantFeature( + ViewerIsInternalCommunitiesAdmin, + false + ).withConstantFeature( + ViewerIsCommunityAdmin, + false + ).withConstantFeature( + ViewerIsCommunityModerator, + false + ).withConstantFeature( + ViewerIsCommunityMember, + false + ).withConstantFeature( + CommunityTweetIsHidden, + false + ).withConstantFeature( + CommunityTweetAuthorIsRemoved, + false + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesPartitioned.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesPartitioned.scala new file mode 100644 index 000000000..695f724eb --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesPartitioned.scala @@ -0,0 +1,26 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.servo.util.Gate +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.models.ViewerContext + +class CommunityTweetFeaturesPartitioned( + a: CommunityTweetFeatures, + b: CommunityTweetFeatures, + bEnabled: Gate[Unit], +) extends CommunityTweetFeatures { + override def forTweet( + tweet: Tweet, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = + bEnabled.pick( + b.forTweet(tweet, viewerContext), + a.forTweet(tweet, viewerContext), + 
) + + override def forTweetOnly(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = bEnabled.pick( + b.forTweetOnly(tweet), + a.forTweetOnly(tweet), + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesV2.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesV2.scala new file mode 100644 index 000000000..407b5308c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/CommunityTweetFeaturesV2.scala @@ -0,0 +1,129 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.communities.moderation.thriftscala.CommunityTweetModerationState +import com.twitter.communities.moderation.thriftscala.CommunityUserModerationState +import com.twitter.communities.visibility.thriftscala.CommunityVisibilityFeatures +import com.twitter.communities.visibility.thriftscala.CommunityVisibilityFeaturesV1 +import com.twitter.communities.visibility.thriftscala.CommunityVisibilityResult +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.CommunitiesSource +import com.twitter.visibility.features.CommunityTweetAuthorIsRemoved +import com.twitter.visibility.features.CommunityTweetCommunityNotFound +import com.twitter.visibility.features.CommunityTweetCommunityDeleted +import com.twitter.visibility.features.CommunityTweetCommunitySuspended +import com.twitter.visibility.features.CommunityTweetCommunityVisible +import com.twitter.visibility.features.CommunityTweetIsHidden +import com.twitter.visibility.features.TweetIsCommunityTweet +import com.twitter.visibility.features.ViewerIsCommunityAdmin +import com.twitter.visibility.features.ViewerIsCommunityMember +import com.twitter.visibility.features.ViewerIsCommunityModerator +import com.twitter.visibility.features.ViewerIsInternalCommunitiesAdmin +import com.twitter.visibility.models.CommunityTweet +import com.twitter.visibility.models.ViewerContext + +class CommunityTweetFeaturesV2(communitiesSource: CommunitiesSource) + extends CommunityTweetFeatures { + private[this] def forCommunityTweet( + communityTweet: CommunityTweet + ): FeatureMapBuilder => FeatureMapBuilder = { builder: FeatureMapBuilder => + { + val communityVisibilityFeaturesStitch = + communitiesSource.getCommunityVisibilityFeatures(communityTweet.communityId) + val communityTweetModerationStateStitch = + communitiesSource.getTweetModerationState(communityTweet.tweet.id) + val communityTweetAuthorModerationStateStitch = + communitiesSource.getUserModerationState( + communityTweet.authorId, + communityTweet.communityId + ) + + def getFlagFromFeatures(f: CommunityVisibilityFeaturesV1 => Boolean): Stitch[Boolean] = + communityVisibilityFeaturesStitch.map { + case Some(CommunityVisibilityFeatures.V1(v1)) => f(v1) + case _ => false + } + + def getFlagFromCommunityVisibilityResult( + f: CommunityVisibilityResult => Boolean + ): Stitch[Boolean] = getFlagFromFeatures { v => + f(v.communityVisibilityResult) + } + + builder + .withConstantFeature( + TweetIsCommunityTweet, + true + ) + .withFeature( + CommunityTweetCommunityNotFound, + getFlagFromCommunityVisibilityResult { + case CommunityVisibilityResult.NotFound => true + case _ => false + } + ) + .withFeature( + CommunityTweetCommunitySuspended, + getFlagFromCommunityVisibilityResult { + case CommunityVisibilityResult.Suspended => true + case _ => false + } + ) + .withFeature( + 
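// Deleted mirrors NotFound and Suspended above: each flag is true only for + // its exact CommunityVisibilityResult case. + 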
CommunityTweetCommunityDeleted, + getFlagFromCommunityVisibilityResult { + case CommunityVisibilityResult.Deleted => true + case _ => false + } + ) + .withFeature( + CommunityTweetCommunityVisible, + getFlagFromCommunityVisibilityResult { + case CommunityVisibilityResult.Visible => true + case _ => false + } + ) + .withFeature( + ViewerIsInternalCommunitiesAdmin, + getFlagFromFeatures { _.viewerIsInternalAdmin } + ) + .withFeature( + ViewerIsCommunityAdmin, + getFlagFromFeatures { _.viewerIsCommunityAdmin } + ) + .withFeature( + ViewerIsCommunityModerator, + getFlagFromFeatures { _.viewerIsCommunityModerator } + ) + .withFeature( + ViewerIsCommunityMember, + getFlagFromFeatures { _.viewerIsCommunityMember } + ) + .withFeature( + CommunityTweetIsHidden, + communityTweetModerationStateStitch.map { + case Some(CommunityTweetModerationState.Hidden(_)) => true + case _ => false + } + ) + .withFeature( + CommunityTweetAuthorIsRemoved, + communityTweetAuthorModerationStateStitch.map { + case Some(CommunityUserModerationState.Removed(_)) => true + case _ => false + } + ) + } + } + + def forTweet( + tweet: Tweet, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = { + CommunityTweet(tweet) match { + case None => forNonCommunityTweet() + case Some(communityTweet) => forCommunityTweet(communityTweet) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ConversationControlFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ConversationControlFeatures.scala new file mode 100644 index 000000000..3f5ad7281 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ConversationControlFeatures.scala @@ -0,0 +1,178 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.tweetypie.thriftscala.ConversationControl +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.common.InvitedToConversationRepo +import com.twitter.visibility.features.ConversationRootAuthorFollowsViewer +import com.twitter.visibility.features.TweetConversationViewerIsInvited +import com.twitter.visibility.features.TweetConversationViewerIsInvitedViaReplyMention +import com.twitter.visibility.features.TweetConversationViewerIsRootAuthor +import com.twitter.visibility.features.TweetHasByInvitationConversationControl +import com.twitter.visibility.features.TweetHasCommunityConversationControl +import com.twitter.visibility.features.TweetHasFollowersConversationControl +import com.twitter.visibility.features.ViewerFollowsConversationRootAuthor + +class ConversationControlFeatures( + relationshipFeatures: RelationshipFeatures, + isInvitedToConversationRepository: InvitedToConversationRepo, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("conversation_control_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val tweetCommunityConversationRequest = + scopedStatsReceiver.scope(TweetHasCommunityConversationControl.name).counter("requests") + private[this] val tweetByInvitationConversationRequest = + scopedStatsReceiver.scope(TweetHasByInvitationConversationControl.name).counter("requests") + private[this] val tweetFollowersConversationRequest = + 
scopedStatsReceiver.scope(TweetHasFollowersConversationControl.name).counter("requests") + private[this] val rootAuthorFollowsViewer = + scopedStatsReceiver.scope(ConversationRootAuthorFollowsViewer.name).counter("requests") + private[this] val viewerFollowsRootAuthor = + scopedStatsReceiver.scope(ViewerFollowsConversationRootAuthor.name).counter("requests") + + def isCommunityConversation(conversationControl: Option[ConversationControl]): Boolean = + conversationControl + .collect { + case _: ConversationControl.Community => + tweetCommunityConversationRequest.incr() + true + }.getOrElse(false) + + def isByInvitationConversation(conversationControl: Option[ConversationControl]): Boolean = + conversationControl + .collect { + case _: ConversationControl.ByInvitation => + tweetByInvitationConversationRequest.incr() + true + }.getOrElse(false) + + def isFollowersConversation(conversationControl: Option[ConversationControl]): Boolean = + conversationControl + .collect { + case _: ConversationControl.Followers => + tweetFollowersConversationRequest.incr() + true + }.getOrElse(false) + + def conversationRootAuthorId( + conversationControl: Option[ConversationControl] + ): Option[Long] = + conversationControl match { + case Some(ConversationControl.Community(community)) => + Some(community.conversationTweetAuthorId) + case Some(ConversationControl.ByInvitation(byInvitation)) => + Some(byInvitation.conversationTweetAuthorId) + case Some(ConversationControl.Followers(followers)) => + Some(followers.conversationTweetAuthorId) + case _ => None + } + + def viewerIsRootAuthor( + conversationControl: Option[ConversationControl], + viewerIdOpt: Option[Long] + ): Boolean = + (conversationRootAuthorId(conversationControl), viewerIdOpt) match { + case (Some(rootAuthorId), Some(viewerId)) if rootAuthorId == viewerId => true + case _ => false + } + + def viewerIsInvited( + conversationControl: Option[ConversationControl], + viewerId: Option[Long] + ): Boolean = { + val invitedUserIds = conversationControl match { + case Some(ConversationControl.Community(community)) => + community.invitedUserIds + case Some(ConversationControl.ByInvitation(byInvitation)) => + byInvitation.invitedUserIds + case Some(ConversationControl.Followers(followers)) => + followers.invitedUserIds + case _ => Seq() + } + + viewerId.exists(invitedUserIds.contains(_)) + } + + def conversationAuthorFollows( + conversationControl: Option[ConversationControl], + viewerId: Option[Long] + ): Stitch[Boolean] = { + val conversationAuthorId = conversationControl.collect { + case ConversationControl.Community(community) => + community.conversationTweetAuthorId + } + + conversationAuthorId match { + case Some(authorId) => + rootAuthorFollowsViewer.incr() + relationshipFeatures.authorFollowsViewer(authorId, viewerId) + case None => + Stitch.False + } + } + + def followsConversationAuthor( + conversationControl: Option[ConversationControl], + viewerId: Option[Long] + ): Stitch[Boolean] = { + val conversationAuthorId = conversationControl.collect { + case ConversationControl.Followers(followers) => + followers.conversationTweetAuthorId + } + + conversationAuthorId match { + case Some(authorId) => + viewerFollowsRootAuthor.incr() + relationshipFeatures.viewerFollowsAuthor(authorId, viewerId) + case None => + Stitch.False + } + } + + def viewerIsInvitedViaReplyMention( + tweet: Tweet, + viewerIdOpt: Option[Long] + ): Stitch[Boolean] = { + val conversationIdOpt: Option[Long] = tweet.conversationControl match { + case 
Some(ConversationControl.Community(community)) + if community.inviteViaMention.contains(true) => + tweet.coreData.flatMap(_.conversationId) + case Some(ConversationControl.ByInvitation(invitation)) + if invitation.inviteViaMention.contains(true) => + tweet.coreData.flatMap(_.conversationId) + case Some(ConversationControl.Followers(followers)) + if followers.inviteViaMention.contains(true) => + tweet.coreData.flatMap(_.conversationId) + case _ => None + } + + (conversationIdOpt, viewerIdOpt) match { + case (Some(conversationId), Some(viewerId)) => + isInvitedToConversationRepository(conversationId, viewerId) + case _ => Stitch.False + } + } + + def forTweet(tweet: Tweet, viewerId: Option[Long]): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + val cc = tweet.conversationControl + + _.withConstantFeature(TweetHasCommunityConversationControl, isCommunityConversation(cc)) + .withConstantFeature(TweetHasByInvitationConversationControl, isByInvitationConversation(cc)) + .withConstantFeature(TweetHasFollowersConversationControl, isFollowersConversation(cc)) + .withConstantFeature(TweetConversationViewerIsRootAuthor, viewerIsRootAuthor(cc, viewerId)) + .withConstantFeature(TweetConversationViewerIsInvited, viewerIsInvited(cc, viewerId)) + .withFeature(ConversationRootAuthorFollowsViewer, conversationAuthorFollows(cc, viewerId)) + .withFeature(ViewerFollowsConversationRootAuthor, followsConversationAuthor(cc, viewerId)) + .withFeature( + TweetConversationViewerIsInvitedViaReplyMention, + viewerIsInvitedViaReplyMention(tweet, viewerId)) + + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/EditTweetFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/EditTweetFeatures.scala new file mode 100644 index 000000000..224f4a9a4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/EditTweetFeatures.scala @@ -0,0 +1,71 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.tweetypie.thriftscala.EditControl +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.TweetIsEditTweet +import com.twitter.visibility.features.TweetIsInitialTweet +import com.twitter.visibility.features.TweetIsLatestTweet +import com.twitter.visibility.features.TweetIsStaleTweet + +class EditTweetFeatures( + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("edit_tweet_features") + private[this] val tweetIsEditTweet = + scopedStatsReceiver.scope(TweetIsEditTweet.name).counter("requests") + private[this] val tweetIsStaleTweet = + scopedStatsReceiver.scope(TweetIsStaleTweet.name).counter("requests") + private[this] val tweetIsLatestTweet = + scopedStatsReceiver.scope(TweetIsLatestTweet.name).counter("requests") + private[this] val tweetIsInitialTweet = + scopedStatsReceiver.scope(TweetIsInitialTweet.name).counter("requests") + + def forTweet( + tweet: Tweet + ): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(TweetIsEditTweet, tweetIsEditTweet(tweet)) + .withConstantFeature(TweetIsStaleTweet, tweetIsStaleTweet(tweet)) + .withConstantFeature(TweetIsLatestTweet, tweetIsLatestTweet(tweet)) + .withConstantFeature(TweetIsInitialTweet, tweetIsInitialTweet(tweet)) + } + + def tweetIsStaleTweet(tweet: Tweet, incrementMetric: Boolean = true): Boolean = { + if (incrementMetric) tweetIsStaleTweet.incr() + + 
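// A tweet is stale when it is not the newest revision, i.e. its id is not + // the last entry in the edit chain's editTweetIds. + 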
tweet.editControl match { + case None => false + case Some(ec) => + ec match { + case eci: EditControl.Initial => eci.initial.editTweetIds.last != tweet.id + case ece: EditControl.Edit => + ece.edit.editControlInitial.exists(_.editTweetIds.last != tweet.id) + case _ => false + } + } + } + + def tweetIsEditTweet(tweet: Tweet, incrementMetric: Boolean = true): Boolean = { + if (incrementMetric) tweetIsEditTweet.incr() + + tweet.editControl match { + case None => false + case Some(ec) => + ec match { + case _: EditControl.Initial => false + case _ => true + } + } + } + + def tweetIsLatestTweet(tweet: Tweet): Boolean = { + tweetIsLatestTweet.incr() + !tweetIsStaleTweet(tweet = tweet, incrementMetric = false) + } + + def tweetIsInitialTweet(tweet: Tweet): Boolean = { + tweetIsInitialTweet.incr() + !tweetIsEditTweet(tweet = tweet, incrementMetric = false) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ExclusiveTweetFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ExclusiveTweetFeatures.scala new file mode 100644 index 000000000..94aee6430 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ExclusiveTweetFeatures.scala @@ -0,0 +1,65 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.ViewerVerbsAuthor +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.features.TweetIsExclusiveTweet +import com.twitter.visibility.features.ViewerIsExclusiveTweetRootAuthor +import com.twitter.visibility.features.ViewerSuperFollowsExclusiveTweetRootAuthor +import com.twitter.visibility.models.ViewerContext + +class ExclusiveTweetFeatures( + userRelationshipSource: UserRelationshipSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("exclusive_tweet_features") + private[this] val viewerSuperFollowsAuthor = + scopedStatsReceiver.scope(ViewerSuperFollowsExclusiveTweetRootAuthor.name).counter("requests") + + def rootAuthorId(tweet: Tweet): Option[Long] = + tweet.exclusiveTweetControl.map(_.conversationAuthorId) + + def viewerIsRootAuthor( + tweet: Tweet, + viewerIdOpt: Option[Long] + ): Boolean = + (rootAuthorId(tweet), viewerIdOpt) match { + case (Some(rootAuthorId), Some(viewerId)) if rootAuthorId == viewerId => true + case _ => false + } + + def viewerSuperFollowsRootAuthor( + tweet: Tweet, + viewerId: Option[Long] + ): Stitch[Boolean] = + rootAuthorId(tweet) match { + case Some(authorId) => + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.superFollows, + viewerSuperFollowsAuthor) + case None => + Stitch.False + } + + def forTweet( + tweet: Tweet, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = { + val viewerId = viewerContext.userId + + _.withConstantFeature(TweetIsExclusiveTweet, tweet.exclusiveTweetControl.isDefined) + .withConstantFeature(ViewerIsExclusiveTweetRootAuthor, viewerIsRootAuthor(tweet, viewerId)) + .withFeature( + ViewerSuperFollowsExclusiveTweetRootAuthor, + viewerSuperFollowsRootAuthor(tweet, viewerId)) + } + + def forTweetOnly(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(TweetIsExclusiveTweet, tweet.exclusiveTweetControl.isDefined) + } +} diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrPrefetchedLabelsRelationshipFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrPrefetchedLabelsRelationshipFeatures.scala new file mode 100644 index 000000000..f9a0445f9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrPrefetchedLabelsRelationshipFeatures.scala @@ -0,0 +1,81 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.ViewerVerbsAuthor +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.features._ +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.ViolationLevel + +class FosnrPrefetchedLabelsRelationshipFeatures( + userRelationshipSource: UserRelationshipSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = + statsReceiver.scope("fosnr_prefetched_relationship_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerFollowsAuthorOfViolatingTweet = + scopedStatsReceiver.scope(ViewerFollowsAuthorOfViolatingTweet.name).counter("requests") + + private[this] val viewerDoesNotFollowAuthorOfViolatingTweet = + scopedStatsReceiver.scope(ViewerDoesNotFollowAuthorOfViolatingTweet.name).counter("requests") + + def forNonFosnr(): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + _.withConstantFeature(ViewerFollowsAuthorOfViolatingTweet, false) + .withConstantFeature(ViewerDoesNotFollowAuthorOfViolatingTweet, false) + } + def forTweetWithSafetyLabelsAndAuthorId( + safetyLabels: Seq[TweetSafetyLabel], + authorId: Long, + viewerId: Option[Long] + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + _.withFeature( + ViewerFollowsAuthorOfViolatingTweet, + viewerFollowsAuthorOfViolatingTweet(safetyLabels, authorId, viewerId)) + .withFeature( + ViewerDoesNotFollowAuthorOfViolatingTweet, + viewerDoesNotFollowAuthorOfViolatingTweet(safetyLabels, authorId, viewerId)) + } + def viewerFollowsAuthorOfViolatingTweet( + safetyLabels: Seq[TweetSafetyLabel], + authorId: UserId, + viewerId: Option[UserId] + ): Stitch[Boolean] = { + if (safetyLabels + .map(ViolationLevel.fromTweetSafetyLabelOpt).collect { + case Some(level) => level + }.isEmpty) { + return Stitch.False + } + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.follows, + viewerFollowsAuthorOfViolatingTweet) + } + def viewerDoesNotFollowAuthorOfViolatingTweet( + safetyLabels: Seq[TweetSafetyLabel], + authorId: UserId, + viewerId: Option[UserId] + ): Stitch[Boolean] = { + if (safetyLabels + .map(ViolationLevel.fromTweetSafetyLabelOpt).collect { + case Some(level) => level + }.isEmpty) { + return Stitch.False + } + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.follows, + viewerDoesNotFollowAuthorOfViolatingTweet).map(following => !following) + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrRelationshipFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrRelationshipFeatures.scala new file mode 100644 index 000000000..a6758eefa --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/FosnrRelationshipFeatures.scala @@ -0,0 +1,82 @@ +package 
com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.users.ViewerVerbsAuthor +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.features._ +import com.twitter.visibility.models.ViolationLevel + +class FosnrRelationshipFeatures( + tweetLabels: TweetLabels, + userRelationshipSource: UserRelationshipSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("fosnr_relationship_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerFollowsAuthorOfViolatingTweet = + scopedStatsReceiver.scope(ViewerFollowsAuthorOfViolatingTweet.name).counter("requests") + + private[this] val viewerDoesNotFollowAuthorOfViolatingTweet = + scopedStatsReceiver.scope(ViewerDoesNotFollowAuthorOfViolatingTweet.name).counter("requests") + + def forTweetAndAuthorId( + tweet: Tweet, + authorId: Long, + viewerId: Option[Long] + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + _.withFeature( + ViewerFollowsAuthorOfViolatingTweet, + viewerFollowsAuthorOfViolatingTweet(tweet, authorId, viewerId)) + .withFeature( + ViewerDoesNotFollowAuthorOfViolatingTweet, + viewerDoesNotFollowAuthorOfViolatingTweet(tweet, authorId, viewerId)) + } + + def viewerFollowsAuthorOfViolatingTweet( + tweet: Tweet, + authorId: UserId, + viewerId: Option[UserId] + ): Stitch[Boolean] = + tweetLabels.forTweet(tweet).flatMap { safetyLabels => + if (safetyLabels + .map(ViolationLevel.fromTweetSafetyLabelOpt).collect { + case Some(level) => level + }.isEmpty) { + Stitch.False + } else { + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.follows, + viewerFollowsAuthorOfViolatingTweet) + } + } + + def viewerDoesNotFollowAuthorOfViolatingTweet( + tweet: Tweet, + authorId: UserId, + viewerId: Option[UserId] + ): Stitch[Boolean] = + tweetLabels.forTweet(tweet).flatMap { safetyLabels => + if (safetyLabels + .map(ViolationLevel.fromTweetSafetyLabelOpt).collect { + case Some(level) => level + }.isEmpty) { + Stitch.False + } else { + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.follows, + viewerDoesNotFollowAuthorOfViolatingTweet).map(following => !following) + } + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/MisinformationPolicyFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/MisinformationPolicyFeatures.scala new file mode 100644 index 000000000..49d01dc14 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/MisinformationPolicyFeatures.scala @@ -0,0 +1,86 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.EscherbirdEntityAnnotations +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.MisinformationPolicySource +import com.twitter.visibility.features._ +import com.twitter.visibility.models.MisinformationPolicy +import com.twitter.visibility.models.SemanticCoreMisinformation +import com.twitter.visibility.models.ViewerContext + +class MisinformationPolicyFeatures( + 
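// Resolves a tweet's Escherbird misinformation annotations into policies, + // localized to the viewer's language and country when available. + 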
misinformationPolicySource: MisinformationPolicySource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = + statsReceiver.scope("misinformation_policy_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + private[this] val tweetMisinformationPolicies = + scopedStatsReceiver.scope(TweetMisinformationPolicies.name).counter("requests") + + def forTweet( + tweet: Tweet, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + tweetMisinformationPolicies.incr() + + _.withFeature( + TweetMisinformationPolicies, + misinformationPolicy(tweet.escherbirdEntityAnnotations, viewerContext)) + .withFeature( + TweetEnglishMisinformationPolicies, + misinformationPolicyEnglishOnly(tweet.escherbirdEntityAnnotations)) + } + + def misinformationPolicyEnglishOnly( + escherbirdEntityAnnotations: Option[EscherbirdEntityAnnotations], + ): Stitch[Seq[MisinformationPolicy]] = { + val locale = Some( + MisinformationPolicySource.LanguageAndCountry( + language = Some("en"), + country = Some("us") + )) + fetchMisinformationPolicy(escherbirdEntityAnnotations, locale) + } + + def misinformationPolicy( + escherbirdEntityAnnotations: Option[EscherbirdEntityAnnotations], + viewerContext: ViewerContext + ): Stitch[Seq[MisinformationPolicy]] = { + val locale = viewerContext.requestLanguageCode.map { language => + MisinformationPolicySource.LanguageAndCountry( + language = Some(language), + country = viewerContext.requestCountryCode + ) + } + fetchMisinformationPolicy(escherbirdEntityAnnotations, locale) + } + + def fetchMisinformationPolicy( + escherbirdEntityAnnotations: Option[EscherbirdEntityAnnotations], + locale: Option[MisinformationPolicySource.LanguageAndCountry] + ): Stitch[Seq[MisinformationPolicy]] = { + Stitch.collect( + escherbirdEntityAnnotations + .map(_.entityAnnotations) + .getOrElse(Seq.empty) + .filter(_.domainId == SemanticCoreMisinformation.domainId) + .map(annotation => + misinformationPolicySource + .fetch( + annotation, + locale + ) + .map(misinformation => + MisinformationPolicy( + annotation = annotation, + misinformation = misinformation + ))) + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ModerationFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ModerationFeatures.scala new file mode 100644 index 000000000..08fc4ad3b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ModerationFeatures.scala @@ -0,0 +1,23 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.TweetIsModerated + +class ModerationFeatures(moderationSource: Long => Boolean, statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver: StatsReceiver = + statsReceiver.scope("moderation_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val tweetIsModerated = + scopedStatsReceiver.scope(TweetIsModerated.name).counter("requests") + + def forTweetId(tweetId: Long): FeatureMapBuilder => FeatureMapBuilder = { featureMapBuilder => + requests.incr() + tweetIsModerated.incr() + + featureMapBuilder.withConstantFeature(TweetIsModerated, moderationSource(tweetId)) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/SearchContextFeatures.scala 
b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/SearchContextFeatures.scala new file mode 100644 index 000000000..8be075ceb --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/SearchContextFeatures.scala @@ -0,0 +1,44 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.SearchCandidateCount +import com.twitter.visibility.features.SearchQueryHasUser +import com.twitter.visibility.features.SearchQuerySource +import com.twitter.visibility.features.SearchResultsPageNumber +import com.twitter.visibility.interfaces.common.search.SearchVFRequestContext + +class SearchContextFeatures( + statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("search_context_features") + private[this] val requests = scopedStatsReceiver.counter("requests") + private[this] val searchResultsPageNumber = + scopedStatsReceiver.scope(SearchResultsPageNumber.name).counter("requests") + private[this] val searchCandidateCount = + scopedStatsReceiver.scope(SearchCandidateCount.name).counter("requests") + private[this] val searchQuerySource = + scopedStatsReceiver.scope(SearchQuerySource.name).counter("requests") + private[this] val searchQueryHasUser = + scopedStatsReceiver.scope(SearchQueryHasUser.name).counter("requests") + + def forSearchContext( + searchContext: SearchVFRequestContext + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + searchResultsPageNumber.incr() + searchCandidateCount.incr() + searchQuerySource.incr() + searchQueryHasUser.incr() + + _.withConstantFeature(SearchResultsPageNumber, searchContext.resultsPageNumber) + .withConstantFeature(SearchCandidateCount, searchContext.candidateCount) + .withConstantFeature( + SearchQuerySource, + searchContext.querySourceOption match { + case Some(querySource) => querySource + case _ => ThriftQuerySource.Unknown + }) + .withConstantFeature(SearchQueryHasUser, searchContext.queryHasUser) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ToxicReplyFilterFeature.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ToxicReplyFilterFeature.scala new file mode 100644 index 000000000..04af003f0 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/ToxicReplyFilterFeature.scala @@ -0,0 +1,57 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.contenthealth.toxicreplyfilter.thriftscala.FilterState +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.ToxicReplyFilterConversationAuthorIsViewer +import com.twitter.visibility.features.ToxicReplyFilterState + +class ToxicReplyFilterFeature( + statsReceiver: StatsReceiver) { + + def forTweet(tweet: Tweet, viewerId: Option[Long]): FeatureMapBuilder => FeatureMapBuilder = { + builder => + requests.incr() + + builder + .withConstantFeature(ToxicReplyFilterState, isTweetFilteredFromAuthor(tweet)) + .withConstantFeature( + ToxicReplyFilterConversationAuthorIsViewer, + isRootAuthorViewer(tweet, viewerId)) + } + + private[this] def isRootAuthorViewer(tweet: Tweet, maybeViewerId: Option[Long]): Boolean = { + val maybeAuthorId = 
tweet.filteredReplyDetails.map(_.conversationAuthorId) + + (maybeViewerId, maybeAuthorId) match { + case (Some(viewerId), Some(authorId)) if viewerId == authorId => { + rootAuthorViewerStats.incr() + true + } + case _ => false + } + } + + private[this] def isTweetFilteredFromAuthor( + tweet: Tweet, + ): FilterState = { + val result = tweet.filteredReplyDetails.map(_.filterState).getOrElse(FilterState.Unfiltered) + + if (result == FilterState.FilteredFromAuthor) { + filteredFromAuthorStats.incr() + } + result + } + + private[this] val scopedStatsReceiver = + statsReceiver.scope("toxicreplyfilter") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val rootAuthorViewerStats = + scopedStatsReceiver.scope(ToxicReplyFilterConversationAuthorIsViewer.name).counter("requests") + + private[this] val filteredFromAuthorStats = + scopedStatsReceiver.scope(ToxicReplyFilterState.name).counter("requests") +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TrustedFriendsFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TrustedFriendsFeatures.scala new file mode 100644 index 000000000..c381c47b3 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TrustedFriendsFeatures.scala @@ -0,0 +1,57 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.TrustedFriendsListId +import com.twitter.visibility.common.TrustedFriendsSource +import com.twitter.visibility.features.TweetIsTrustedFriendTweet +import com.twitter.visibility.features.ViewerIsTrustedFriendOfTweetAuthor +import com.twitter.visibility.features.ViewerIsTrustedFriendTweetAuthor + +class TrustedFriendsFeatures(trustedFriendsSource: TrustedFriendsSource) { + + private[builder] def viewerIsTrustedFriend( + tweet: Tweet, + viewerId: Option[Long] + ): Stitch[Boolean] = + (trustedFriendsListId(tweet), viewerId) match { + case (Some(tfListId), Some(userId)) => + trustedFriendsSource.isTrustedFriend(tfListId, userId) + case _ => Stitch.False + } + + private[builder] def viewerIsTrustedFriendListOwner( + tweet: Tweet, + viewerId: Option[Long] + ): Stitch[Boolean] = + (trustedFriendsListId(tweet), viewerId) match { + case (Some(tfListId), Some(userId)) => + trustedFriendsSource.isTrustedFriendListOwner(tfListId, userId) + case _ => Stitch.False + } + + private[builder] def trustedFriendsListId(tweet: Tweet): Option[TrustedFriendsListId] = + tweet.trustedFriendsControl.map(_.trustedFriendsListId) + + def forTweet( + tweet: Tweet, + viewerId: Option[Long] + ): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature( + TweetIsTrustedFriendTweet, + tweet.trustedFriendsControl.isDefined + ).withFeature( + ViewerIsTrustedFriendTweetAuthor, + viewerIsTrustedFriendListOwner(tweet, viewerId) + ).withFeature( + ViewerIsTrustedFriendOfTweetAuthor, + viewerIsTrustedFriend(tweet, viewerId) + ) + } + + def forTweetOnly(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(TweetIsTrustedFriendTweet, tweet.trustedFriendsControl.isDefined) + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetFeatures.scala new file mode 100644 index 000000000..e28f0bda8 --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetFeatures.scala @@ -0,0 +1,210 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.snowflake.id.SnowflakeId +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.CollabControl +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.util.Duration +import com.twitter.util.Time +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.SafetyLabelMapSource +import com.twitter.visibility.common.TweetId +import com.twitter.visibility.common.UserId +import com.twitter.visibility.features._ +import com.twitter.visibility.models.SemanticCoreAnnotation +import com.twitter.visibility.models.TweetSafetyLabel + +object TweetFeatures { + + def FALLBACK_TIMESTAMP: Time = Time.epoch + + def tweetIsSelfReply(tweet: Tweet): Boolean = { + tweet.coreData match { + case Some(coreData) => + coreData.reply match { + case Some(reply) => + reply.inReplyToUserId == coreData.userId + + case None => + false + } + + case None => + false + } + } + + def tweetReplyToParentTweetDuration(tweet: Tweet): Option[Duration] = for { + coreData <- tweet.coreData + reply <- coreData.reply + inReplyToStatusId <- reply.inReplyToStatusId + replyTime <- SnowflakeId.timeFromIdOpt(tweet.id) + repliedToTime <- SnowflakeId.timeFromIdOpt(inReplyToStatusId) + } yield { + replyTime.diff(repliedToTime) + } + + def tweetReplyToRootTweetDuration(tweet: Tweet): Option[Duration] = for { + coreData <- tweet.coreData + if coreData.reply.isDefined + conversationId <- coreData.conversationId + replyTime <- SnowflakeId.timeFromIdOpt(tweet.id) + rootTime <- SnowflakeId.timeFromIdOpt(conversationId) + } yield { + replyTime.diff(rootTime) + } + + def tweetTimestamp(tweetId: Long): Time = + SnowflakeId.timeFromIdOpt(tweetId).getOrElse(FALLBACK_TIMESTAMP) + + def tweetSemanticCoreAnnotations(tweet: Tweet): Seq[SemanticCoreAnnotation] = { + tweet.escherbirdEntityAnnotations + .map(a => + a.entityAnnotations.map { annotation => + SemanticCoreAnnotation( + annotation.groupId, + annotation.domainId, + annotation.entityId + ) + }).toSeq.flatten + } + + def tweetIsNullcast(tweet: Tweet): Boolean = { + tweet.coreData match { + case Some(coreData) => + coreData.nullcast + case None => + false + } + } + + def tweetAuthorUserId(tweet: Tweet): Option[UserId] = { + tweet.coreData.map(_.userId) + } +} + +sealed trait TweetLabels { + def forTweet(tweet: Tweet): Stitch[Seq[TweetSafetyLabel]] + def forTweetId(tweetId: TweetId): Stitch[Seq[TweetSafetyLabel]] +} + +class StratoTweetLabelMaps(safetyLabelSource: SafetyLabelMapSource) extends TweetLabels { + + override def forTweet(tweet: Tweet): Stitch[Seq[TweetSafetyLabel]] = { + forTweetId(tweet.id) + } + + def forTweetId(tweetId: TweetId): Stitch[Seq[TweetSafetyLabel]] = { + safetyLabelSource + .fetch(tweetId).map( + _.map( + _.labels + .map( + _.map(sl => TweetSafetyLabel.fromTuple(sl._1, sl._2)).toSeq + ).getOrElse(Seq()) + ).getOrElse(Seq())) + } +} + +object NilTweetLabelMaps extends TweetLabels { + override def forTweet(tweet: Tweet): Stitch[Seq[TweetSafetyLabel]] = Stitch.Nil + override def forTweetId(tweetId: TweetId): Stitch[Seq[TweetSafetyLabel]] = Stitch.Nil +} + +class TweetFeatures(tweetLabels: TweetLabels, statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("tweet_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + 
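+  // One "requests" counter scoped per feature name, mirroring the convention used
+  // by the other feature builders in this package.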
private[this] val tweetSafetyLabels = + scopedStatsReceiver.scope(TweetSafetyLabels.name).counter("requests") + private[this] val tweetTakedownReasons = + scopedStatsReceiver.scope(TweetTakedownReasons.name).counter("requests") + private[this] val tweetIsSelfReply = + scopedStatsReceiver.scope(TweetIsSelfReply.name).counter("requests") + private[this] val tweetTimestamp = + scopedStatsReceiver.scope(TweetTimestamp.name).counter("requests") + private[this] val tweetReplyToParentTweetDuration = + scopedStatsReceiver.scope(TweetReplyToParentTweetDuration.name).counter("requests") + private[this] val tweetReplyToRootTweetDuration = + scopedStatsReceiver.scope(TweetReplyToRootTweetDuration.name).counter("requests") + private[this] val tweetSemanticCoreAnnotations = + scopedStatsReceiver.scope(TweetSemanticCoreAnnotations.name).counter("requests") + private[this] val tweetId = + scopedStatsReceiver.scope(TweetId.name).counter("requests") + private[this] val tweetHasNsfwUser = + scopedStatsReceiver.scope(TweetHasNsfwUser.name).counter("requests") + private[this] val tweetHasNsfwAdmin = + scopedStatsReceiver.scope(TweetHasNsfwAdmin.name).counter("requests") + private[this] val tweetIsNullcast = + scopedStatsReceiver.scope(TweetIsNullcast.name).counter("requests") + private[this] val tweetHasMedia = + scopedStatsReceiver.scope(TweetHasMedia.name).counter("requests") + private[this] val tweetIsCommunity = + scopedStatsReceiver.scope(TweetIsCommunityTweet.name).counter("requests") + private[this] val tweetIsCollabInvitation = + scopedStatsReceiver.scope(TweetIsCollabInvitationTweet.name).counter("requests") + + def forTweet(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = { + forTweetWithoutSafetyLabels(tweet) + .andThen(_.withFeature(TweetSafetyLabels, tweetLabels.forTweet(tweet))) + } + + def forTweetWithoutSafetyLabels(tweet: Tweet): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + tweetTakedownReasons.incr() + tweetIsSelfReply.incr() + tweetTimestamp.incr() + tweetReplyToParentTweetDuration.incr() + tweetReplyToRootTweetDuration.incr() + tweetSemanticCoreAnnotations.incr() + tweetId.incr() + tweetHasNsfwUser.incr() + tweetHasNsfwAdmin.incr() + tweetIsNullcast.incr() + tweetHasMedia.incr() + tweetIsCommunity.incr() + tweetIsCollabInvitation.incr() + + _.withConstantFeature(TweetTakedownReasons, tweet.takedownReasons.getOrElse(Seq.empty)) + .withConstantFeature(TweetIsSelfReply, TweetFeatures.tweetIsSelfReply(tweet)) + .withConstantFeature(TweetTimestamp, TweetFeatures.tweetTimestamp(tweet.id)) + .withConstantFeature( + TweetReplyToParentTweetDuration, + TweetFeatures.tweetReplyToParentTweetDuration(tweet)) + .withConstantFeature( + TweetReplyToRootTweetDuration, + TweetFeatures.tweetReplyToRootTweetDuration(tweet)) + .withConstantFeature( + TweetSemanticCoreAnnotations, + TweetFeatures.tweetSemanticCoreAnnotations(tweet)) + .withConstantFeature(TweetId, tweet.id) + .withConstantFeature(TweetHasNsfwUser, tweetHasNsfwUser(tweet)) + .withConstantFeature(TweetHasNsfwAdmin, tweetHasNsfwAdmin(tweet)) + .withConstantFeature(TweetIsNullcast, TweetFeatures.tweetIsNullcast(tweet)) + .withConstantFeature(TweetHasMedia, tweetHasMedia(tweet)) + .withConstantFeature(TweetIsCommunityTweet, tweetHasCommunity(tweet)) + .withConstantFeature(TweetIsCollabInvitationTweet, tweetIsCollabInvitation(tweet)) + } + + def tweetHasNsfwUser(tweet: Tweet): Boolean = + tweet.coreData.exists(_.nsfwUser) + + def tweetHasNsfwAdmin(tweet: Tweet): Boolean = + tweet.coreData.exists(_.nsfwAdmin) + + def 
tweetHasMedia(tweet: Tweet): Boolean = + tweet.coreData.exists(_.hasMedia.getOrElse(false)) + + def tweetHasCommunity(tweet: Tweet): Boolean = { + tweet.communities.exists(_.communityIds.nonEmpty) + } + + def tweetIsCollabInvitation(tweet: Tweet): Boolean = { + tweet.collabControl.exists(_ match { + case CollabControl.CollabInvitation(_) => true + case _ => false + }) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetIdFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetIdFeatures.scala new file mode 100644 index 000000000..b284b4ab6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetIdFeatures.scala @@ -0,0 +1,76 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.servo.util.Gate +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.spam.rtf.thriftscala.SafetyLabelValue +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.stitch.StitchHelpers +import com.twitter.visibility.features.TweetId +import com.twitter.visibility.features.TweetSafetyLabels +import com.twitter.visibility.features.TweetTimestamp +import com.twitter.visibility.models.TweetSafetyLabel + +class TweetIdFeatures( + statsReceiver: StatsReceiver, + enableStitchProfiling: Gate[Unit]) { + private[this] val scopedStatsReceiver: StatsReceiver = statsReceiver.scope("tweet_id_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + private[this] val tweetSafetyLabels = + scopedStatsReceiver.scope(TweetSafetyLabels.name).counter("requests") + private[this] val tweetTimestamp = + scopedStatsReceiver.scope(TweetTimestamp.name).counter("requests") + + private[this] val labelFetchScope: StatsReceiver = + scopedStatsReceiver.scope("labelFetch") + + private[this] def getTweetLabels( + tweetId: Long, + labelFetcher: Long => Stitch[Map[SafetyLabelType, SafetyLabel]] + ): Stitch[Seq[TweetSafetyLabel]] = { + val stitch = + labelFetcher(tweetId).map { labelMap => + labelMap + .map { case (labelType, label) => SafetyLabelValue(labelType, label) }.toSeq + .map(TweetSafetyLabel.fromThrift) + } + + if (enableStitchProfiling()) { + StitchHelpers.profileStitch( + stitch, + Seq(labelFetchScope) + ) + } else { + stitch + } + } + + def forTweetId( + tweetId: Long, + labelFetcher: Long => Stitch[Map[SafetyLabelType, SafetyLabel]] + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + tweetSafetyLabels.incr() + tweetTimestamp.incr() + + _.withFeature(TweetSafetyLabels, getTweetLabels(tweetId, labelFetcher)) + .withConstantFeature(TweetTimestamp, TweetFeatures.tweetTimestamp(tweetId)) + .withConstantFeature(TweetId, tweetId) + } + + def forTweetId( + tweetId: Long, + constantTweetSafetyLabels: Seq[TweetSafetyLabel] + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + tweetSafetyLabels.incr() + tweetTimestamp.incr() + + _.withConstantFeature(TweetSafetyLabels, constantTweetSafetyLabels) + .withConstantFeature(TweetTimestamp, TweetFeatures.tweetTimestamp(tweetId)) + .withConstantFeature(TweetId, tweetId) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetMediaMetadataFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetMediaMetadataFeatures.scala new file mode 100644 index 
000000000..e421bd7b6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetMediaMetadataFeatures.scala @@ -0,0 +1,130 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.commons.mediainformation.thriftscala.AdditionalMetadata +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.TweetMediaMetadataSource +import com.twitter.visibility.features.HasDmcaMediaFeature +import com.twitter.visibility.features.MediaGeoRestrictionsAllowList +import com.twitter.visibility.features.MediaGeoRestrictionsDenyList + +class TweetMediaMetadataFeatures( + mediaMetadataSource: TweetMediaMetadataSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("tweet_media_metadata_features") + private[this] val reportedStats = scopedStatsReceiver.scope("dmcaStats") + + def forTweet( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ): FeatureMapBuilder => FeatureMapBuilder = { featureMapBuilder => + featureMapBuilder.withFeature( + HasDmcaMediaFeature, + mediaIsDmca(tweet, mediaKeys, enableFetchMediaMetadata)) + featureMapBuilder.withFeature( + MediaGeoRestrictionsAllowList, + allowlist(tweet, mediaKeys, enableFetchMediaMetadata)) + featureMapBuilder.withFeature( + MediaGeoRestrictionsDenyList, + denylist(tweet, mediaKeys, enableFetchMediaMetadata)) + } + + private def mediaIsDmca( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ) = getMediaAdditionalMetadata(tweet, mediaKeys, enableFetchMediaMetadata) + .map(_.exists(_.restrictions.exists(_.isDmca))) + + private def allowlist( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ) = getMediaGeoRestrictions(tweet, mediaKeys, enableFetchMediaMetadata) + .map(_.flatMap(_.whitelistedCountryCodes)) + + private def denylist( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ) = getMediaGeoRestrictions(tweet, mediaKeys, enableFetchMediaMetadata) + .map(_.flatMap(_.blacklistedCountryCodes)) + + private def getMediaGeoRestrictions( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ) = { + getMediaAdditionalMetadata(tweet, mediaKeys, enableFetchMediaMetadata) + .map(additionalMetadatasSeq => { + for { + additionalMetadata <- additionalMetadatasSeq + restrictions <- additionalMetadata.restrictions + geoRestrictions <- restrictions.geoRestrictions + } yield { + geoRestrictions + } + }) + } + + private def getMediaAdditionalMetadata( + tweet: Tweet, + mediaKeys: Seq[GenericMediaKey], + enableFetchMediaMetadata: Boolean + ): Stitch[Seq[AdditionalMetadata]] = { + if (mediaKeys.isEmpty) { + reportedStats.counter("empty").incr() + Stitch.value(Seq.empty) + } else { + tweet.media.flatMap { mediaEntities => + val alreadyHydratedMetadata = mediaEntities + .filter(_.mediaKey.isDefined) + .flatMap(_.additionalMetadata) + + if (alreadyHydratedMetadata.nonEmpty) { + Some(alreadyHydratedMetadata) + } else { + None + } + } match { + case Some(additionalMetadata) => + reportedStats.counter("already_hydrated").incr() + Stitch.value(additionalMetadata) + case None => + Stitch + .collect( + mediaKeys.map(fetchAdditionalMetadata(tweet.id, _, 
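+              // One metadata fetch is issued per media key; fetches that resolve to
+              // None are dropped from the collected sequence just below.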
enableFetchMediaMetadata)) + ).map(maybeMetadatas => { + maybeMetadatas + .filter(_.isDefined) + .map(_.get) + }) + } + } + } + + private def fetchAdditionalMetadata( + tweetId: Long, + genericMediaKey: GenericMediaKey, + enableFetchMediaMetadata: Boolean + ): Stitch[Option[AdditionalMetadata]] = + if (enableFetchMediaMetadata) { + genericMediaKey.toThriftMediaKey() match { + case Some(mediaKey) => + reportedStats.counter("request").incr() + mediaMetadataSource.fetch(tweetId, mediaKey) + case None => + reportedStats.counter("empty_key").incr() + Stitch.None + } + } else { + reportedStats.counter("light_request").incr() + Stitch.None + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetPerspectiveFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetPerspectiveFeatures.scala new file mode 100644 index 000000000..e2e25740b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetPerspectiveFeatures.scala @@ -0,0 +1,54 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.TweetPerspectiveSource +import com.twitter.visibility.features.ViewerReportedTweet + +class TweetPerspectiveFeatures( + tweetPerspectiveSource: TweetPerspectiveSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("tweet_perspective_features") + private[this] val reportedStats = scopedStatsReceiver.scope("reported") + + def forTweet( + tweet: Tweet, + viewerId: Option[Long], + enableFetchReportedPerspective: Boolean + ): FeatureMapBuilder => FeatureMapBuilder = + _.withFeature( + ViewerReportedTweet, + tweetIsReported(tweet, viewerId, enableFetchReportedPerspective)) + + private[builder] def tweetIsReported( + tweet: Tweet, + viewerId: Option[Long], + enableFetchReportedPerspective: Boolean = true + ): Stitch[Boolean] = { + ((tweet.perspective, viewerId) match { + case (Some(perspective), _) => + Stitch.value(perspective.reported).onSuccess { _ => + reportedStats.counter("already_hydrated").incr() + } + case (None, Some(viewerId)) => + if (enableFetchReportedPerspective) { + tweetPerspectiveSource.reported(tweet.id, viewerId).onSuccess { _ => + reportedStats.counter("request").incr() + } + } else { + Stitch.False.onSuccess { _ => + reportedStats.counter("light_request").incr() + } + } + case _ => + Stitch.False.onSuccess { _ => + reportedStats.counter("empty").incr() + } + }).onSuccess { _ => + reportedStats.counter("").incr() + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetVisibilityNudgeSourceWrapper.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetVisibilityNudgeSourceWrapper.scala new file mode 100644 index 000000000..99ad6a46a --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/TweetVisibilityNudgeSourceWrapper.scala @@ -0,0 +1,39 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.spam.rtf.thriftscala.SafetyLabelType.ExperimentalNudge +import com.twitter.spam.rtf.thriftscala.SafetyLabelType.SemanticCoreMisinformation +import com.twitter.spam.rtf.thriftscala.SafetyLabelType.UnsafeUrl +import com.twitter.visibility.common.LocalizedNudgeSource +import 
com.twitter.visibility.common.actions.TweetVisibilityNudgeReason +import com.twitter.visibility.common.actions.TweetVisibilityNudgeReason.ExperimentalNudgeSafetyLabelReason +import com.twitter.visibility.common.actions.TweetVisibilityNudgeReason.SemanticCoreMisinformationLabelReason +import com.twitter.visibility.common.actions.TweetVisibilityNudgeReason.UnsafeURLLabelReason +import com.twitter.visibility.rules.LocalizedNudge + +class TweetVisibilityNudgeSourceWrapper(localizedNudgeSource: LocalizedNudgeSource) { + + def getLocalizedNudge( + reason: TweetVisibilityNudgeReason, + languageCode: String, + countryCode: Option[String] + ): Option[LocalizedNudge] = + reason match { + case ExperimentalNudgeSafetyLabelReason => + fetchNudge(ExperimentalNudge, languageCode, countryCode) + case SemanticCoreMisinformationLabelReason => + fetchNudge(SemanticCoreMisinformation, languageCode, countryCode) + case UnsafeURLLabelReason => + fetchNudge(UnsafeUrl, languageCode, countryCode) + } + + private def fetchNudge( + safetyLabel: SafetyLabelType, + languageCode: String, + countryCode: Option[String] + ): Option[LocalizedNudge] = { + localizedNudgeSource + .fetch(safetyLabel, languageCode, countryCode) + .map(LocalizedNudge.fromStratoThrift) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/UnmentionNotificationFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/UnmentionNotificationFeatures.scala new file mode 100644 index 000000000..513627009 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/tweets/UnmentionNotificationFeatures.scala @@ -0,0 +1,75 @@ +package com.twitter.visibility.builder.tweets + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.notificationservice.model.notification._ +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.SettingsUnmentions +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.TweetSource +import com.twitter.visibility.features.NotificationIsOnUnmentionedViewer + +object UnmentionNotificationFeatures { + def ForNonUnmentionNotificationFeatures: FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(NotificationIsOnUnmentionedViewer, false) + } +} + +class UnmentionNotificationFeatures( + tweetSource: TweetSource, + enableUnmentionHydration: Gate[Long], + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = + statsReceiver.scope("unmention_notification_features") + private[this] val requestsCounter = scopedStatsReceiver.counter("requests") + private[this] val hydrationsCounter = scopedStatsReceiver.counter("hydrations") + private[this] val notificationsUnmentionUserCounter = + scopedStatsReceiver + .scope(NotificationIsOnUnmentionedViewer.name).counter("unmentioned_users") + + def forNotification(notification: Notification): FeatureMapBuilder => FeatureMapBuilder = { + requestsCounter.incr() + + val isUnmentionNotification = tweetId(notification) match { + case Some(tweetId) if enableUnmentionHydration(notification.target) => + hydrationsCounter.incr() + tweetSource + .getTweet(tweetId) + .map { + case Some(tweet) => + tweet.settingsUnmentions match { + case Some(SettingsUnmentions(Some(unmentionedUserIds))) => + if (unmentionedUserIds.contains(notification.target)) { + notificationsUnmentionUserCounter.incr() + true + } else { + false + } + case _ => false + } + case _ => false + } + case _ => Stitch.False + } + 
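+    // `_` is the incoming FeatureMapBuilder: the hydrated (or constant-false) Stitch
+    // above is attached as the NotificationIsOnUnmentionedViewer feature.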
_.withFeature(NotificationIsOnUnmentionedViewer, isUnmentionNotification) + } + + private[this] def tweetId(notification: Notification): Option[Long] = { + notification match { + case n: MentionNotification => Some(n.mentioningTweetId) + case n: FavoritedMentioningTweetNotification => Some(n.mentioningTweetId) + case n: FavoritedReplyToYourTweetNotification => Some(n.replyTweetId) + case n: MentionQuoteNotification => Some(n.mentioningTweetId) + case n: ReactionMentioningTweetNotification => Some(n.mentioningTweetId) + case n: ReplyNotification => Some(n.replyingTweetId) + case n: RetweetedMentionNotification => Some(n.mentioningTweetId) + case n: RetweetedReplyToYourTweetNotification => Some(n.replyTweetId) + case n: ReplyToConversationNotification => Some(n.replyingTweetId) + case n: ReactionReplyToYourTweetNotification => Some(n.replyTweetId) + case _ => None + } + + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorDeviceFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorDeviceFeatures.scala new file mode 100644 index 000000000..5ffa3c107 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorDeviceFeatures.scala @@ -0,0 +1,39 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserDeviceSource +import com.twitter.visibility.features.AuthorHasConfirmedEmail +import com.twitter.visibility.features.AuthorHasVerifiedPhone + +class AuthorDeviceFeatures(userDeviceSource: UserDeviceSource, statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("author_device_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val authorHasConfirmedEmailRequests = + scopedStatsReceiver.scope(AuthorHasConfirmedEmail.name).counter("requests") + private[this] val authorHasVerifiedPhoneRequests = + scopedStatsReceiver.scope(AuthorHasVerifiedPhone.name).counter("requests") + + def forAuthor(author: User): FeatureMapBuilder => FeatureMapBuilder = forAuthorId(author.id) + + def forAuthorId(authorId: Long): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withFeature(AuthorHasConfirmedEmail, authorHasConfirmedEmail(authorId)) + .withFeature(AuthorHasVerifiedPhone, authorHasVerifiedPhone(authorId)) + } + + def authorHasConfirmedEmail(authorId: Long): Stitch[Boolean] = { + authorHasConfirmedEmailRequests.incr() + userDeviceSource.hasConfirmedEmail(authorId) + } + + def authorHasVerifiedPhone(authorId: Long): Stitch[Boolean] = { + authorHasVerifiedPhoneRequests.incr() + userDeviceSource.hasConfirmedPhone(authorId) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorFeatures.scala new file mode 100644 index 000000000..bf6529691 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/AuthorFeatures.scala @@ -0,0 +1,221 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.gizmoduck.thriftscala.Labels +import com.twitter.gizmoduck.thriftscala.Profile +import com.twitter.gizmoduck.thriftscala.Safety +import 
com.twitter.gizmoduck.thriftscala.User +import com.twitter.stitch.NotFound +import com.twitter.stitch.Stitch +import com.twitter.tseng.withholding.thriftscala.TakedownReason +import com.twitter.util.Duration +import com.twitter.util.Time +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features._ + +class AuthorFeatures(userSource: UserSource, statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("author_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val authorUserLabels = + scopedStatsReceiver.scope(AuthorUserLabels.name).counter("requests") + private[this] val authorIsSuspended = + scopedStatsReceiver.scope(AuthorIsSuspended.name).counter("requests") + private[this] val authorIsProtected = + scopedStatsReceiver.scope(AuthorIsProtected.name).counter("requests") + private[this] val authorIsDeactivated = + scopedStatsReceiver.scope(AuthorIsDeactivated.name).counter("requests") + private[this] val authorIsErased = + scopedStatsReceiver.scope(AuthorIsErased.name).counter("requests") + private[this] val authorIsOffboarded = + scopedStatsReceiver.scope(AuthorIsOffboarded.name).counter("requests") + private[this] val authorIsNsfwUser = + scopedStatsReceiver.scope(AuthorIsNsfwUser.name).counter("requests") + private[this] val authorIsNsfwAdmin = + scopedStatsReceiver.scope(AuthorIsNsfwAdmin.name).counter("requests") + private[this] val authorTakedownReasons = + scopedStatsReceiver.scope(AuthorTakedownReasons.name).counter("requests") + private[this] val authorHasDefaultProfileImage = + scopedStatsReceiver.scope(AuthorHasDefaultProfileImage.name).counter("requests") + private[this] val authorAccountAge = + scopedStatsReceiver.scope(AuthorAccountAge.name).counter("requests") + private[this] val authorIsVerified = + scopedStatsReceiver.scope(AuthorIsVerified.name).counter("requests") + private[this] val authorScreenName = + scopedStatsReceiver.scope(AuthorScreenName.name).counter("requests") + private[this] val authorIsBlueVerified = + scopedStatsReceiver.scope(AuthorIsBlueVerified.name).counter("requests") + + def forAuthor(author: User): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withConstantFeature(AuthorId, Set(author.id)) + .withConstantFeature(AuthorUserLabels, authorUserLabels(author)) + .withConstantFeature(AuthorIsProtected, authorIsProtected(author)) + .withConstantFeature(AuthorIsSuspended, authorIsSuspended(author)) + .withConstantFeature(AuthorIsDeactivated, authorIsDeactivated(author)) + .withConstantFeature(AuthorIsErased, authorIsErased(author)) + .withConstantFeature(AuthorIsOffboarded, authorIsOffboarded(author)) + .withConstantFeature(AuthorTakedownReasons, authorTakedownReasons(author)) + .withConstantFeature(AuthorHasDefaultProfileImage, authorHasDefaultProfileImage(author)) + .withConstantFeature(AuthorAccountAge, authorAccountAge(author)) + .withConstantFeature(AuthorIsNsfwUser, authorIsNsfwUser(author)) + .withConstantFeature(AuthorIsNsfwAdmin, authorIsNsfwAdmin(author)) + .withConstantFeature(AuthorIsVerified, authorIsVerified(author)) + .withConstantFeature(AuthorScreenName, authorScreenName(author)) + .withConstantFeature(AuthorIsBlueVerified, authorIsBlueVerified(author)) + } + + def forAuthorNoDefaults(author: User): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withConstantFeature(AuthorId, Set(author.id)) + .withConstantFeature(AuthorUserLabels, 
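+        // The *Opt helpers return Option values, preserving "unknown" (None) when the
+        // hydrated User carries no safety data, instead of defaulting to false.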
authorUserLabelsOpt(author)) + .withConstantFeature(AuthorIsProtected, authorIsProtectedOpt(author)) + .withConstantFeature(AuthorIsSuspended, authorIsSuspendedOpt(author)) + .withConstantFeature(AuthorIsDeactivated, authorIsDeactivatedOpt(author)) + .withConstantFeature(AuthorIsErased, authorIsErasedOpt(author)) + .withConstantFeature(AuthorIsOffboarded, authorIsOffboarded(author)) + .withConstantFeature(AuthorTakedownReasons, authorTakedownReasons(author)) + .withConstantFeature(AuthorHasDefaultProfileImage, authorHasDefaultProfileImage(author)) + .withConstantFeature(AuthorAccountAge, authorAccountAge(author)) + .withConstantFeature(AuthorIsNsfwUser, authorIsNsfwUserOpt(author)) + .withConstantFeature(AuthorIsNsfwAdmin, authorIsNsfwAdminOpt(author)) + .withConstantFeature(AuthorIsVerified, authorIsVerifiedOpt(author)) + .withConstantFeature(AuthorScreenName, authorScreenName(author)) + .withConstantFeature(AuthorIsBlueVerified, authorIsBlueVerified(author)) + } + + def forAuthorId(authorId: Long): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withConstantFeature(AuthorId, Set(authorId)) + .withFeature(AuthorUserLabels, authorUserLabels(authorId)) + .withFeature(AuthorIsProtected, authorIsProtected(authorId)) + .withFeature(AuthorIsSuspended, authorIsSuspended(authorId)) + .withFeature(AuthorIsDeactivated, authorIsDeactivated(authorId)) + .withFeature(AuthorIsErased, authorIsErased(authorId)) + .withFeature(AuthorIsOffboarded, authorIsOffboarded(authorId)) + .withFeature(AuthorTakedownReasons, authorTakedownReasons(authorId)) + .withFeature(AuthorHasDefaultProfileImage, authorHasDefaultProfileImage(authorId)) + .withFeature(AuthorAccountAge, authorAccountAge(authorId)) + .withFeature(AuthorIsNsfwUser, authorIsNsfwUser(authorId)) + .withFeature(AuthorIsNsfwAdmin, authorIsNsfwAdmin(authorId)) + .withFeature(AuthorIsVerified, authorIsVerified(authorId)) + .withFeature(AuthorScreenName, authorScreenName(authorId)) + .withFeature(AuthorIsBlueVerified, authorIsBlueVerified(authorId)) + } + + def forNoAuthor(): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(AuthorId, Set.empty[Long]) + .withConstantFeature(AuthorUserLabels, Seq.empty) + .withConstantFeature(AuthorIsProtected, false) + .withConstantFeature(AuthorIsSuspended, false) + .withConstantFeature(AuthorIsDeactivated, false) + .withConstantFeature(AuthorIsErased, false) + .withConstantFeature(AuthorIsOffboarded, false) + .withConstantFeature(AuthorTakedownReasons, Seq.empty) + .withConstantFeature(AuthorHasDefaultProfileImage, false) + .withConstantFeature(AuthorAccountAge, Duration.Zero) + .withConstantFeature(AuthorIsNsfwUser, false) + .withConstantFeature(AuthorIsNsfwAdmin, false) + .withConstantFeature(AuthorIsVerified, false) + .withConstantFeature(AuthorIsBlueVerified, false) + } + + def authorUserLabels(author: User): Seq[Label] = + authorUserLabels(author.labels) + + def authorIsSuspended(authorId: Long): Stitch[Boolean] = + userSource.getSafety(authorId).map(safety => authorIsSuspended(Some(safety))) + + def authorIsSuspendedOpt(author: User): Option[Boolean] = { + authorIsSuspended.incr() + author.safety.map(_.suspended) + } + + private def authorIsSuspended(safety: Option[Safety]): Boolean = { + authorIsSuspended.incr() + safety.exists(_.suspended) + } + + def authorIsProtected(author: User): Boolean = + authorIsProtected(author.safety) + + def authorIsDeactivated(authorId: Long): Stitch[Boolean] = + userSource.getSafety(authorId).map(safety => authorIsDeactivated(Some(safety))) + + def 
authorIsDeactivatedOpt(author: User): Option[Boolean] = { + authorIsDeactivated.incr() + author.safety.map(_.deactivated) + } + + private def authorIsDeactivated(safety: Option[Safety]): Boolean = { + authorIsDeactivated.incr() + safety.exists(_.deactivated) + } + + def authorIsErased(author: User): Boolean = { + authorIsErased.incr() + author.safety.exists(_.erased) + } + + def authorIsOffboarded(authorId: Long): Stitch[Boolean] = { + userSource.getSafety(authorId).map(safety => authorIsOffboarded(Some(safety))) + } + + def authorIsNsfwUser(author: User): Boolean = { + authorIsNsfwUser(author.safety) + } + + def authorIsNsfwUser(authorId: Long): Stitch[Boolean] = { + userSource.getSafety(authorId).map(safety => authorIsNsfwUser(Some(safety))) + } + + def authorIsNsfwUser(safety: Option[Safety]): Boolean = { + authorIsNsfwUser.incr() + safety.exists(_.nsfwUser) + } + + def authorIsNsfwAdminOpt(author: User): Option[Boolean] = { + authorIsNsfwAdmin.incr() + author.safety.map(_.nsfwAdmin) + } + + def authorTakedownReasons(authorId: Long): Stitch[Seq[TakedownReason]] = { + authorTakedownReasons.incr() + userSource.getTakedownReasons(authorId) + } + + def authorHasDefaultProfileImage(authorId: Long): Stitch[Boolean] = + userSource.getProfile(authorId).map(profile => authorHasDefaultProfileImage(Some(profile))) + + def authorAccountAge(authorId: Long): Stitch[Duration] = + userSource.getCreatedAtMsec(authorId).map(authorAccountAgeFromTimestamp) + + def authorIsVerified(authorId: Long): Stitch[Boolean] = + userSource.getSafety(authorId).map(safety => authorIsVerified(Some(safety))) + + def authorIsVerifiedOpt(author: User): Option[Boolean] = { + authorIsVerified.incr() + author.safety.map(_.verified) + } + + private def authorIsVerified(safety: Option[Safety]): Boolean = { + authorIsVerified.incr() + safety.exists(_.verified) + } + + def authorScreenName(author: User): Option[String] = { + authorScreenName.incr() + author.profile.map(_.screenName) + } + + def authorScreenName(authorId: Long): Stitch[String] = { + authorScreenName.incr() + userSource.getProfile(authorId).map(profile => profile.screenName) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/BUILD new file mode 100644 index 000000000..9da789b38 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/BUILD @@ -0,0 +1,22 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "stitch/stitch-core", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/user_result", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/blender", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/search", + "visibility/lib/src/main/thrift/com/twitter/visibility/context:vf-context-scala", + ], +) diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/QuotedTweetFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/QuotedTweetFeatures.scala new file mode 100644 index 000000000..aac96d26f --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/QuotedTweetFeatures.scala @@ -0,0 +1,52 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features.AuthorBlocksOuterAuthor +import com.twitter.visibility.features.OuterAuthorFollowsAuthor +import com.twitter.visibility.features.OuterAuthorId +import com.twitter.visibility.features.OuterAuthorIsInnerAuthor + +class QuotedTweetFeatures( + relationshipFeatures: RelationshipFeatures, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("quoted_tweet_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val outerAuthorIdStat = + scopedStatsReceiver.scope(OuterAuthorId.name).counter("requests") + private[this] val authorBlocksOuterAuthor = + scopedStatsReceiver.scope(AuthorBlocksOuterAuthor.name).counter("requests") + private[this] val outerAuthorFollowsAuthor = + scopedStatsReceiver.scope(OuterAuthorFollowsAuthor.name).counter("requests") + private[this] val outerAuthorIsInnerAuthor = + scopedStatsReceiver.scope(OuterAuthorIsInnerAuthor.name).counter("requests") + + def forOuterAuthor( + outerAuthorId: Long, + innerAuthorId: Long + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + outerAuthorIdStat.incr() + authorBlocksOuterAuthor.incr() + outerAuthorFollowsAuthor.incr() + outerAuthorIsInnerAuthor.incr() + + val viewer = Some(outerAuthorId) + + _.withConstantFeature(OuterAuthorId, outerAuthorId) + .withFeature( + AuthorBlocksOuterAuthor, + relationshipFeatures.authorBlocksViewer(innerAuthorId, viewer)) + .withFeature( + OuterAuthorFollowsAuthor, + relationshipFeatures.viewerFollowsAuthor(innerAuthorId, viewer)) + .withConstantFeature( + OuterAuthorIsInnerAuthor, + innerAuthorId == outerAuthorId + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipFeatures.scala new file mode 100644 index 000000000..9795e9408 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipFeatures.scala @@ -0,0 +1,176 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.features._ + +class RelationshipFeatures( + userRelationshipSource: UserRelationshipSource, + statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("relationship_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val authorFollowsViewer = + scopedStatsReceiver.scope(AuthorFollowsViewer.name).counter("requests") + private[this] val viewerFollowsAuthor = + scopedStatsReceiver.scope(ViewerFollowsAuthor.name).counter("requests") + private[this] val authorBlocksViewer = + 
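+    // These counter vals deliberately share names with the def overloads below; the
+    // bare references passed as Counter arguments (e.g. viewerFollowsAuthor) resolve
+    // to the vals, not the methods.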
scopedStatsReceiver.scope(AuthorBlocksViewer.name).counter("requests") + private[this] val viewerBlocksAuthor = + scopedStatsReceiver.scope(ViewerBlocksAuthor.name).counter("requests") + private[this] val authorMutesViewer = + scopedStatsReceiver.scope(AuthorMutesViewer.name).counter("requests") + private[this] val viewerMutesAuthor = + scopedStatsReceiver.scope(ViewerMutesAuthor.name).counter("requests") + private[this] val authorHasReportedViewer = + scopedStatsReceiver.scope(AuthorReportsViewerAsSpam.name).counter("requests") + private[this] val viewerHasReportedAuthor = + scopedStatsReceiver.scope(ViewerReportsAuthorAsSpam.name).counter("requests") + private[this] val viewerMutesRetweetsFromAuthor = + scopedStatsReceiver.scope(ViewerMutesRetweetsFromAuthor.name).counter("requests") + + def forAuthorId( + authorId: Long, + viewerId: Option[Long] + ): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withFeature(AuthorFollowsViewer, authorFollowsViewer(authorId, viewerId)) + .withFeature(ViewerFollowsAuthor, viewerFollowsAuthor(authorId, viewerId)) + .withFeature(AuthorBlocksViewer, authorBlocksViewer(authorId, viewerId)) + .withFeature(ViewerBlocksAuthor, viewerBlocksAuthor(authorId, viewerId)) + .withFeature(AuthorMutesViewer, authorMutesViewer(authorId, viewerId)) + .withFeature(ViewerMutesAuthor, viewerMutesAuthor(authorId, viewerId)) + .withFeature(AuthorReportsViewerAsSpam, authorHasReportedViewer(authorId, viewerId)) + .withFeature(ViewerReportsAuthorAsSpam, viewerHasReportedAuthor(authorId, viewerId)) + .withFeature(ViewerMutesRetweetsFromAuthor, viewerMutesRetweetsFromAuthor(authorId, viewerId)) + } + + def forNoAuthor(): FeatureMapBuilder => FeatureMapBuilder = { + _.withConstantFeature(AuthorFollowsViewer, false) + .withConstantFeature(ViewerFollowsAuthor, false) + .withConstantFeature(AuthorBlocksViewer, false) + .withConstantFeature(ViewerBlocksAuthor, false) + .withConstantFeature(AuthorMutesViewer, false) + .withConstantFeature(ViewerMutesAuthor, false) + .withConstantFeature(AuthorReportsViewerAsSpam, false) + .withConstantFeature(ViewerReportsAuthorAsSpam, false) + .withConstantFeature(ViewerMutesRetweetsFromAuthor, false) + } + + def forAuthor(author: User, viewerId: Option[Long]): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + + _.withFeature(AuthorFollowsViewer, authorFollowsViewer(author, viewerId)) + .withFeature(ViewerFollowsAuthor, viewerFollowsAuthor(author, viewerId)) + .withFeature(AuthorBlocksViewer, authorBlocksViewer(author, viewerId)) + .withFeature(ViewerBlocksAuthor, viewerBlocksAuthor(author, viewerId)) + .withFeature(AuthorMutesViewer, authorMutesViewer(author, viewerId)) + .withFeature(ViewerMutesAuthor, viewerMutesAuthor(author, viewerId)) + .withFeature(AuthorReportsViewerAsSpam, authorHasReportedViewer(author.id, viewerId)) + .withFeature(ViewerReportsAuthorAsSpam, viewerHasReportedAuthor(author.id, viewerId)) + .withFeature(ViewerMutesRetweetsFromAuthor, viewerMutesRetweetsFromAuthor(author, viewerId)) + } + + def viewerFollowsAuthor(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor(authorId, viewerId, userRelationshipSource.follows, viewerFollowsAuthor) + + def viewerFollowsAuthor(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.following, + userRelationshipSource.follows, + viewerFollowsAuthor) + + def authorFollowsViewer(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + AuthorVerbsViewer(authorId, 
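+      // AuthorVerbsViewer checks the author -> viewer edge, i.e. it evaluates
+      // relationship(authorId, viewerId); ViewerVerbsAuthor checks the reverse.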
viewerId, userRelationshipSource.follows, authorFollowsViewer) + + def authorFollowsViewer(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + AuthorVerbsViewer( + author, + viewerId, + p => p.followedBy, + userRelationshipSource.follows, + authorFollowsViewer) + + def viewerBlocksAuthor(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor(authorId, viewerId, userRelationshipSource.blocks, viewerBlocksAuthor) + + def viewerBlocksAuthor(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.blocking, + userRelationshipSource.blocks, + viewerBlocksAuthor) + + def authorBlocksViewer(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor(authorId, viewerId, userRelationshipSource.blockedBy, authorBlocksViewer) + + def authorBlocksViewer(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.blockedBy, + userRelationshipSource.blockedBy, + authorBlocksViewer) + + def viewerMutesAuthor(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor(authorId, viewerId, userRelationshipSource.mutes, viewerMutesAuthor) + + def viewerMutesAuthor(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.muting, + userRelationshipSource.mutes, + viewerMutesAuthor) + + def authorMutesViewer(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor(authorId, viewerId, userRelationshipSource.mutedBy, authorMutesViewer) + + def authorMutesViewer(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.mutedBy, + userRelationshipSource.mutedBy, + authorMutesViewer) + + def viewerHasReportedAuthor(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.reportsAsSpam, + viewerHasReportedAuthor) + + def authorHasReportedViewer(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.reportedAsSpamBy, + authorHasReportedViewer) + + def viewerMutesRetweetsFromAuthor(authorId: UserId, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + authorId, + viewerId, + userRelationshipSource.noRetweetsFrom, + viewerMutesRetweetsFromAuthor) + + def viewerMutesRetweetsFromAuthor(author: User, viewerId: Option[UserId]): Stitch[Boolean] = + ViewerVerbsAuthor( + author, + viewerId, + p => p.noRetweetsFrom, + userRelationshipSource.noRetweetsFrom, + viewerMutesRetweetsFromAuthor) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipVerbHelpers.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipVerbHelpers.scala new file mode 100644 index 000000000..0e654a69f --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/RelationshipVerbHelpers.scala @@ -0,0 +1,79 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.Counter +import com.twitter.gizmoduck.thriftscala.Perspective +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.stitch.Stitch +import com.twitter.visibility.common.UserId + +case object ViewerVerbsAuthor { + def apply( + authorId: UserId, + viewerIdOpt: Option[UserId], + relationship: (UserId, UserId) => Stitch[Boolean], + relationshipCounter: Counter + ): Stitch[Boolean] = { + 
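+    // Every lookup is counted, even for logged-out viewers, which short-circuit
+    // to Stitch.False without touching the relationship source.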
relationshipCounter.incr() + + viewerIdOpt match { + case Some(viewerId) => relationship(viewerId, authorId) + case _ => Stitch.False + } + } + + def apply( + author: User, + viewerId: Option[UserId], + checkPerspective: Perspective => Option[Boolean], + relationship: (UserId, UserId) => Stitch[Boolean], + relationshipCounter: Counter + ): Stitch[Boolean] = { + author.perspective match { + case Some(perspective) => + checkPerspective(perspective) match { + case Some(status) => + relationshipCounter.incr() + Stitch.value(status) + case None => + ViewerVerbsAuthor(author.id, viewerId, relationship, relationshipCounter) + } + case None => ViewerVerbsAuthor(author.id, viewerId, relationship, relationshipCounter) + } + } +} + +case object AuthorVerbsViewer { + + def apply( + authorId: UserId, + viewerIdOpt: Option[UserId], + relationship: (UserId, UserId) => Stitch[Boolean], + relationshipCounter: Counter + ): Stitch[Boolean] = { + relationshipCounter.incr() + + viewerIdOpt match { + case Some(viewerId) => relationship(authorId, viewerId) + case _ => Stitch.False + } + } + def apply( + author: User, + viewerId: Option[UserId], + checkPerspective: Perspective => Option[Boolean], + relationship: (UserId, UserId) => Stitch[Boolean], + relationshipCounter: Counter + ): Stitch[Boolean] = { + author.perspective match { + case Some(perspective) => + checkPerspective(perspective) match { + case Some(status) => + relationshipCounter.incr() + Stitch.value(status) + case None => + AuthorVerbsViewer(author.id, viewerId, relationship, relationshipCounter) + } + case None => AuthorVerbsViewer(author.id, viewerId, relationship, relationshipCounter) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/SearchFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/SearchFeatures.scala new file mode 100644 index 000000000..7602f2788 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/SearchFeatures.scala @@ -0,0 +1,26 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.features._ +import com.twitter.visibility.context.thriftscala.SearchContext + +class SearchFeatures(statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("search_features") + private[this] val requests = scopedStatsReceiver.counter("requests") + private[this] val rawQueryCounter = + scopedStatsReceiver.scope(RawQuery.name).counter("requests") + + def forSearchContext( + searchContext: Option[SearchContext] + ): FeatureMapBuilder => FeatureMapBuilder = { builder => + requests.incr() + searchContext match { + case Some(context: SearchContext) => + rawQueryCounter.incr() + builder + .withConstantFeature(RawQuery, context.rawQuery) + case _ => builder + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/UserUnavailableFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/UserUnavailableFeatures.scala new file mode 100644 index 000000000..4a196fe5f --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/UserUnavailableFeatures.scala @@ -0,0 +1,145 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.user_result.UserVisibilityResultHelper +import 
com.twitter.visibility.features.AuthorBlocksViewer +import com.twitter.visibility.features.AuthorIsDeactivated +import com.twitter.visibility.features.AuthorIsErased +import com.twitter.visibility.features.AuthorIsOffboarded +import com.twitter.visibility.features.AuthorIsProtected +import com.twitter.visibility.features.AuthorIsSuspended +import com.twitter.visibility.features.AuthorIsUnavailable +import com.twitter.visibility.features.ViewerBlocksAuthor +import com.twitter.visibility.features.ViewerMutesAuthor +import com.twitter.visibility.models.UserUnavailableStateEnum + +case class UserUnavailableFeatures(statsReceiver: StatsReceiver) { + + private[this] val scopedStatsReceiver = statsReceiver.scope("user_unavailable_features") + private[this] val suspendedAuthorStats = scopedStatsReceiver.scope("suspended_author") + private[this] val deactivatedAuthorStats = scopedStatsReceiver.scope("deactivated_author") + private[this] val offboardedAuthorStats = scopedStatsReceiver.scope("offboarded_author") + private[this] val erasedAuthorStats = scopedStatsReceiver.scope("erased_author") + private[this] val protectedAuthorStats = scopedStatsReceiver.scope("protected_author") + private[this] val authorBlocksViewerStats = scopedStatsReceiver.scope("author_blocks_viewer") + private[this] val viewerBlocksAuthorStats = scopedStatsReceiver.scope("viewer_blocks_author") + private[this] val viewerMutesAuthorStats = scopedStatsReceiver.scope("viewer_mutes_author") + private[this] val unavailableStats = scopedStatsReceiver.scope("unavailable") + + def forState(state: UserUnavailableStateEnum): FeatureMapBuilder => FeatureMapBuilder = { + builder => + builder + .withConstantFeature(AuthorIsSuspended, isSuspended(state)) + .withConstantFeature(AuthorIsDeactivated, isDeactivated(state)) + .withConstantFeature(AuthorIsOffboarded, isOffboarded(state)) + .withConstantFeature(AuthorIsErased, isErased(state)) + .withConstantFeature(AuthorIsProtected, isProtected(state)) + .withConstantFeature(AuthorBlocksViewer, authorBlocksViewer(state)) + .withConstantFeature(ViewerBlocksAuthor, viewerBlocksAuthor(state)) + .withConstantFeature(ViewerMutesAuthor, viewerMutesAuthor(state)) + .withConstantFeature(AuthorIsUnavailable, isUnavailable(state)) + } + + private[this] def isSuspended(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Suspended => + suspendedAuthorStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if UserVisibilityResultHelper.isDropSuspendedAuthor(result) => + suspendedAuthorStats.counter().incr() + suspendedAuthorStats.counter("filtered").incr() + true + case _ => false + } + + private[this] def isDeactivated(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Deactivated => + deactivatedAuthorStats.counter().incr() + true + case _ => false + } + + private[this] def isOffboarded(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Offboarded => + offboardedAuthorStats.counter().incr() + true + case _ => false + } + + private[this] def isErased(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Erased => + erasedAuthorStats.counter().incr() + true + case _ => false + } + + private[this] def isProtected(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Protected => + protectedAuthorStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if 
UserVisibilityResultHelper.isDropProtectedAuthor(result) => + protectedAuthorStats.counter().incr() + protectedAuthorStats.counter("filtered").incr() + true + case _ => false + } + + private[this] def authorBlocksViewer(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.AuthorBlocksViewer => + authorBlocksViewerStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if UserVisibilityResultHelper.isDropAuthorBlocksViewer(result) => + authorBlocksViewerStats.counter().incr() + authorBlocksViewerStats.counter("filtered").incr() + true + case _ => false + } + + private[this] def viewerBlocksAuthor(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.ViewerBlocksAuthor => + viewerBlocksAuthorStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if UserVisibilityResultHelper.isDropViewerBlocksAuthor(result) => + viewerBlocksAuthorStats.counter().incr() + viewerBlocksAuthorStats.counter("filtered").incr() + true + case _ => false + } + + private[this] def viewerMutesAuthor(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.ViewerMutesAuthor => + viewerMutesAuthorStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if UserVisibilityResultHelper.isDropViewerMutesAuthor(result) => + viewerMutesAuthorStats.counter().incr() + viewerMutesAuthorStats.counter("filtered").incr() + true + case _ => false + } + + private[this] def isUnavailable(state: UserUnavailableStateEnum): Boolean = + state match { + case UserUnavailableStateEnum.Unavailable => + unavailableStats.counter().incr() + true + case UserUnavailableStateEnum.Filtered(result) + if UserVisibilityResultHelper.isDropUnspecifiedAuthor(result) => + unavailableStats.counter().incr() + unavailableStats.counter("filtered").incr() + true + case _ => false + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerAdvancedFilteringFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerAdvancedFilteringFeatures.scala new file mode 100644 index 000000000..38b3106f0 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerAdvancedFilteringFeatures.scala @@ -0,0 +1,92 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.Counter +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.AdvancedFilters +import com.twitter.gizmoduck.thriftscala.MentionFilter +import com.twitter.stitch.NotFound +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.ViewerFiltersDefaultProfileImage +import com.twitter.visibility.features.ViewerFiltersNewUsers +import com.twitter.visibility.features.ViewerFiltersNoConfirmedEmail +import com.twitter.visibility.features.ViewerFiltersNoConfirmedPhone +import com.twitter.visibility.features.ViewerFiltersNotFollowedBy +import com.twitter.visibility.features.ViewerMentionFilter + +class ViewerAdvancedFilteringFeatures(userSource: UserSource, statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("viewer_advanced_filtering_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerFiltersNoConfirmedEmail = + 
scopedStatsReceiver.scope(ViewerFiltersNoConfirmedEmail.name).counter("requests") + private[this] val viewerFiltersNoConfirmedPhone = + scopedStatsReceiver.scope(ViewerFiltersNoConfirmedPhone.name).counter("requests") + private[this] val viewerFiltersDefaultProfileImage = + scopedStatsReceiver.scope(ViewerFiltersDefaultProfileImage.name).counter("requests") + private[this] val viewerFiltersNewUsers = + scopedStatsReceiver.scope(ViewerFiltersNewUsers.name).counter("requests") + private[this] val viewerFiltersNotFollowedBy = + scopedStatsReceiver.scope(ViewerFiltersNotFollowedBy.name).counter("requests") + private[this] val viewerMentionFilter = + scopedStatsReceiver.scope(ViewerMentionFilter.name).counter("requests") + + def forViewerId(viewerId: Option[Long]): FeatureMapBuilder => FeatureMapBuilder = { + requests.incr() + + _.withFeature(ViewerFiltersNoConfirmedEmail, viewerFiltersNoConfirmedEmail(viewerId)) + .withFeature(ViewerFiltersNoConfirmedPhone, viewerFiltersNoConfirmedPhone(viewerId)) + .withFeature(ViewerFiltersDefaultProfileImage, viewerFiltersDefaultProfileImage(viewerId)) + .withFeature(ViewerFiltersNewUsers, viewerFiltersNewUsers(viewerId)) + .withFeature(ViewerFiltersNotFollowedBy, viewerFiltersNotFollowedBy(viewerId)) + .withFeature(ViewerMentionFilter, viewerMentionFilter(viewerId)) + } + + def viewerFiltersNoConfirmedEmail(viewerId: Option[Long]): Stitch[Boolean] = + viewerAdvancedFilters(viewerId, af => af.filterNoConfirmedEmail, viewerFiltersNoConfirmedEmail) + + def viewerFiltersNoConfirmedPhone(viewerId: Option[Long]): Stitch[Boolean] = + viewerAdvancedFilters(viewerId, af => af.filterNoConfirmedPhone, viewerFiltersNoConfirmedPhone) + + def viewerFiltersDefaultProfileImage(viewerId: Option[Long]): Stitch[Boolean] = + viewerAdvancedFilters( + viewerId, + af => af.filterDefaultProfileImage, + viewerFiltersDefaultProfileImage + ) + + def viewerFiltersNewUsers(viewerId: Option[Long]): Stitch[Boolean] = + viewerAdvancedFilters(viewerId, af => af.filterNewUsers, viewerFiltersNewUsers) + + def viewerFiltersNotFollowedBy(viewerId: Option[Long]): Stitch[Boolean] = + viewerAdvancedFilters(viewerId, af => af.filterNotFollowedBy, viewerFiltersNotFollowedBy) + + def viewerMentionFilter(viewerId: Option[Long]): Stitch[MentionFilter] = { + viewerMentionFilter.incr() + viewerId match { + case Some(id) => + userSource.getMentionFilter(id).handle { + case NotFound => + MentionFilter.Unfiltered + } + case _ => Stitch.value(MentionFilter.Unfiltered) + } + } + + private[this] def viewerAdvancedFilters( + viewerId: Option[Long], + advancedFilterCheck: AdvancedFilters => Boolean, + featureCounter: Counter + ): Stitch[Boolean] = { + featureCounter.incr() + + val advancedFilters = viewerId match { + case Some(id) => userSource.getAdvancedFilters(id) + case _ => Stitch.value(AdvancedFilters()) + } + + advancedFilters.map(advancedFilterCheck) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerFeatures.scala new file mode 100644 index 000000000..4e97ce5d4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerFeatures.scala @@ -0,0 +1,245 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.Counter +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.gizmoduck.thriftscala.Safety +import 
com.twitter.gizmoduck.thriftscala.UniversalQualityFiltering +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.gizmoduck.thriftscala.UserType +import com.twitter.stitch.NotFound +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features._ +import com.twitter.visibility.interfaces.common.blender.BlenderVFRequestContext +import com.twitter.visibility.interfaces.common.search.SearchVFRequestContext +import com.twitter.visibility.models.UserAge +import com.twitter.visibility.models.ViewerContext + +class ViewerFeatures(userSource: UserSource, statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("viewer_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerIdCount = + scopedStatsReceiver.scope(ViewerId.name).counter("requests") + private[this] val requestCountryCode = + scopedStatsReceiver.scope(RequestCountryCode.name).counter("requests") + private[this] val requestIsVerifiedCrawler = + scopedStatsReceiver.scope(RequestIsVerifiedCrawler.name).counter("requests") + private[this] val viewerUserLabels = + scopedStatsReceiver.scope(ViewerUserLabels.name).counter("requests") + private[this] val viewerIsDeactivated = + scopedStatsReceiver.scope(ViewerIsDeactivated.name).counter("requests") + private[this] val viewerIsProtected = + scopedStatsReceiver.scope(ViewerIsProtected.name).counter("requests") + private[this] val viewerIsSuspended = + scopedStatsReceiver.scope(ViewerIsSuspended.name).counter("requests") + private[this] val viewerRoles = + scopedStatsReceiver.scope(ViewerRoles.name).counter("requests") + private[this] val viewerCountryCode = + scopedStatsReceiver.scope(ViewerCountryCode.name).counter("requests") + private[this] val viewerAge = + scopedStatsReceiver.scope(ViewerAge.name).counter("requests") + private[this] val viewerHasUniversalQualityFilterEnabled = + scopedStatsReceiver.scope(ViewerHasUniversalQualityFilterEnabled.name).counter("requests") + private[this] val viewerIsSoftUserCtr = + scopedStatsReceiver.scope(ViewerIsSoftUser.name).counter("requests") + + def forViewerBlenderContext( + blenderContext: BlenderVFRequestContext, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = + forViewerContext(viewerContext) + .andThen( + _.withConstantFeature( + ViewerOptInBlocking, + blenderContext.userSearchSafetySettings.optInBlocking) + .withConstantFeature( + ViewerOptInFiltering, + blenderContext.userSearchSafetySettings.optInFiltering) + ) + + def forViewerSearchContext( + searchContext: SearchVFRequestContext, + viewerContext: ViewerContext + ): FeatureMapBuilder => FeatureMapBuilder = + forViewerContext(viewerContext) + .andThen( + _.withConstantFeature( + ViewerOptInBlocking, + searchContext.userSearchSafetySettings.optInBlocking) + .withConstantFeature( + ViewerOptInFiltering, + searchContext.userSearchSafetySettings.optInFiltering) + ) + + def forViewerContext(viewerContext: ViewerContext): FeatureMapBuilder => FeatureMapBuilder = + forViewerId(viewerContext.userId) + .andThen( + _.withConstantFeature(RequestCountryCode, requestCountryCode(viewerContext)) + ).andThen( + _.withConstantFeature(RequestIsVerifiedCrawler, requestIsVerifiedCrawler(viewerContext)) + ) + + def forViewerId(viewerId: Option[UserId]): FeatureMapBuilder => FeatureMapBuilder = { builder => + requests.incr() + 
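+    // Hydrate viewer-scoped features from the injected UserSource. The
+    // ViewerCountryCode feature is only added when a viewer id is present,
+    // so logged-out requests skip that extra account lookup.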
+ val builderWithFeatures = builder + .withConstantFeature(ViewerId, viewerId) + .withFeature(ViewerIsProtected, viewerIsProtected(viewerId)) + .withFeature( + ViewerHasUniversalQualityFilterEnabled, + viewerHasUniversalQualityFilterEnabled(viewerId) + ) + .withFeature(ViewerIsSuspended, viewerIsSuspended(viewerId)) + .withFeature(ViewerIsDeactivated, viewerIsDeactivated(viewerId)) + .withFeature(ViewerUserLabels, viewerUserLabels(viewerId)) + .withFeature(ViewerRoles, viewerRoles(viewerId)) + .withFeature(ViewerAge, viewerAgeInYears(viewerId)) + .withFeature(ViewerIsSoftUser, viewerIsSoftUser(viewerId)) + + viewerId match { + case Some(_) => + viewerIdCount.incr() + builderWithFeatures + .withFeature(ViewerCountryCode, viewerCountryCode(viewerId)) + + case _ => + builderWithFeatures + } + } + + def forViewerNoDefaults(viewerOpt: Option[User]): FeatureMapBuilder => FeatureMapBuilder = { + builder => + requests.incr() + + viewerOpt match { + case Some(viewer) => + builder + .withConstantFeature(ViewerId, viewer.id) + .withConstantFeature(ViewerIsProtected, viewerIsProtectedOpt(viewer)) + .withConstantFeature(ViewerIsSuspended, viewerIsSuspendedOpt(viewer)) + .withConstantFeature(ViewerIsDeactivated, viewerIsDeactivatedOpt(viewer)) + .withConstantFeature(ViewerCountryCode, viewerCountryCode(viewer)) + case None => + builder + .withConstantFeature(ViewerIsProtected, false) + .withConstantFeature(ViewerIsSuspended, false) + .withConstantFeature(ViewerIsDeactivated, false) + } + } + + private def checkSafetyValue( + viewerId: Option[UserId], + safetyCheck: Safety => Boolean, + featureCounter: Counter + ): Stitch[Boolean] = + viewerId match { + case Some(id) => + userSource.getSafety(id).map(safetyCheck).ensure { + featureCounter.incr() + } + case None => Stitch.False + } + + private def checkSafetyValue( + viewer: User, + safetyCheck: Safety => Boolean + ): Boolean = { + viewer.safety.exists(safetyCheck) + } + + def viewerIsProtected(viewerId: Option[UserId]): Stitch[Boolean] = + checkSafetyValue(viewerId, s => s.isProtected, viewerIsProtected) + + def viewerIsProtected(viewer: User): Boolean = + checkSafetyValue(viewer, s => s.isProtected) + + def viewerIsProtectedOpt(viewer: User): Option[Boolean] = + viewer.safety.map(_.isProtected) + + def viewerIsDeactivated(viewerId: Option[UserId]): Stitch[Boolean] = + checkSafetyValue(viewerId, s => s.deactivated, viewerIsDeactivated) + + def viewerIsDeactivated(viewer: User): Boolean = + checkSafetyValue(viewer, s => s.deactivated) + + def viewerIsDeactivatedOpt(viewer: User): Option[Boolean] = + viewer.safety.map(_.deactivated) + + def viewerHasUniversalQualityFilterEnabled(viewerId: Option[UserId]): Stitch[Boolean] = + checkSafetyValue( + viewerId, + s => s.universalQualityFiltering == UniversalQualityFiltering.Enabled, + viewerHasUniversalQualityFilterEnabled + ) + + def viewerUserLabels(viewerIdOpt: Option[UserId]): Stitch[Seq[Label]] = + viewerIdOpt match { + case Some(viewerId) => + userSource + .getLabels(viewerId).map(_.labels) + .handle { + case NotFound => Seq.empty + }.ensure { + viewerUserLabels.incr() + } + case _ => Stitch.value(Seq.empty) + } + + def viewerAgeInYears(viewerId: Option[UserId]): Stitch[UserAge] = + (viewerId match { + case Some(id) => + userSource + .getExtendedProfile(id).map(_.ageInYears) + .handle { + case NotFound => None + }.ensure { + viewerAge.incr() + } + case _ => Stitch.value(None) + }).map(UserAge) + + def viewerIsSoftUser(viewerId: Option[UserId]): Stitch[Boolean] = { + viewerId match { + case Some(id) => + 
userSource + .getUserType(id).map { userType => + userType == UserType.Soft + }.ensure { + viewerIsSoftUserCtr.incr() + } + case _ => Stitch.False + } + } + + def requestCountryCode(viewerContext: ViewerContext): Option[String] = { + requestCountryCode.incr() + viewerContext.requestCountryCode + } + + def requestIsVerifiedCrawler(viewerContext: ViewerContext): Boolean = { + requestIsVerifiedCrawler.incr() + viewerContext.isVerifiedCrawler + } + + def viewerCountryCode(viewerId: Option[UserId]): Stitch[String] = + viewerId match { + case Some(id) => + userSource + .getAccount(id).map(_.countryCode).flatMap { + case Some(countryCode) => Stitch.value(countryCode.toLowerCase) + case _ => Stitch.NotFound + }.ensure { + viewerCountryCode.incr() + } + + case _ => Stitch.NotFound + } + + def viewerCountryCode(viewer: User): Option[String] = + viewer.account.flatMap(_.countryCode) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSearchSafetyFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSearchSafetyFeatures.scala new file mode 100644 index 000000000..6cddcd74c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSearchSafetyFeatures.scala @@ -0,0 +1,49 @@ +package com.twitter.visibility.builder.users + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserSearchSafetySource +import com.twitter.visibility.features.ViewerId +import com.twitter.visibility.features.ViewerOptInBlocking +import com.twitter.visibility.features.ViewerOptInFiltering + +class ViewerSearchSafetyFeatures( + userSearchSafetySource: UserSearchSafetySource, + statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = statsReceiver.scope("viewer_search_safety_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + private[this] val viewerOptInBlockingRequests = + scopedStatsReceiver.scope(ViewerOptInBlocking.name).counter("requests") + + private[this] val viewerOptInFilteringRequests = + scopedStatsReceiver.scope(ViewerOptInFiltering.name).counter("requests") + + def forViewerId(viewerId: Option[UserId]): FeatureMapBuilder => FeatureMapBuilder = { builder => + requests.incr() + + builder + .withConstantFeature(ViewerId, viewerId) + .withFeature(ViewerOptInBlocking, viewerOptInBlocking(viewerId)) + .withFeature(ViewerOptInFiltering, viewerOptInFiltering(viewerId)) + } + + def viewerOptInBlocking(viewerId: Option[UserId]): Stitch[Boolean] = { + viewerOptInBlockingRequests.incr() + viewerId match { + case Some(userId) => userSearchSafetySource.optInBlocking(userId) + case _ => Stitch.False + } + } + + def viewerOptInFiltering(viewerId: Option[UserId]): Stitch[Boolean] = { + viewerOptInFilteringRequests.incr() + viewerId match { + case Some(userId) => userSearchSafetySource.optInFiltering(userId) + case _ => Stitch.False + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSensitiveMediaSettingsFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSensitiveMediaSettingsFeatures.scala new file mode 100644 index 000000000..31b982886 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/builder/users/ViewerSensitiveMediaSettingsFeatures.scala @@ -0,0 +1,41 @@ +package com.twitter.visibility.builder.users + 
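+// Hydrates the viewer's sensitive media settings into the feature map, falling
+// back to an empty setting (None) when the backing source has no record for
+// the user. A minimal usage sketch; the `source` and `builder` values below
+// are illustrative, not part of this file:
+//
+//   val features = new ViewerSensitiveMediaSettingsFeatures(source, statsReceiver)
+//   val hydrated = features.forViewerId(Some(viewerId))(builder)
+//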
+import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.stitch.NotFound +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserSensitiveMediaSettingsSource +import com.twitter.visibility.features.ViewerId +import com.twitter.visibility.features.ViewerSensitiveMediaSettings +import com.twitter.visibility.models.UserSensitiveMediaSettings + + +class ViewerSensitiveMediaSettingsFeatures( + userSensitiveMediaSettingsSource: UserSensitiveMediaSettingsSource, + statsReceiver: StatsReceiver) { + private[this] val scopedStatsReceiver = + statsReceiver.scope("viewer_sensitive_media_settings_features") + + private[this] val requests = scopedStatsReceiver.counter("requests") + + def forViewerId(viewerId: Option[UserId]): FeatureMapBuilder => FeatureMapBuilder = { builder => + requests.incr() + + builder + .withConstantFeature(ViewerId, viewerId) + .withFeature(ViewerSensitiveMediaSettings, viewerSensitiveMediaSettings(viewerId)) + } + + def viewerSensitiveMediaSettings(viewerId: Option[UserId]): Stitch[UserSensitiveMediaSettings] = { + (viewerId match { + case Some(userId) => + userSensitiveMediaSettingsSource + .userSensitiveMediaSettings(userId) + .handle { + case NotFound => None + } + case _ => Stitch.value(None) + }).map(UserSensitiveMediaSettings) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/BUILD new file mode 100644 index 000000000..b0562b356 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/BUILD @@ -0,0 +1,19 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "abdecider/src/main/scala", + "configapi/configapi-abdecider", + "configapi/configapi-core", + "configapi/configapi-featureswitches:v2", + "decider", + "featureswitches/featureswitches-core/src/main/scala", + "finagle/finagle-stats", + "servo/decider/src/main/scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/ConfigBuilder.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/ConfigBuilder.scala new file mode 100644 index 000000000..df634d400 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/ConfigBuilder.scala @@ -0,0 +1,43 @@ +package com.twitter.visibility.configapi + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logging.Logger +import com.twitter.servo.decider.DeciderGateBuilder +import com.twitter.timelines.configapi.CompositeConfig +import com.twitter.timelines.configapi.Config +import com.twitter.util.Memoize +import com.twitter.visibility.configapi.configs.VisibilityDeciders +import com.twitter.visibility.configapi.configs.VisibilityExperimentsConfig +import com.twitter.visibility.configapi.configs.VisibilityFeatureSwitches +import com.twitter.visibility.models.SafetyLevel + +object ConfigBuilder { + + def apply(statsReceiver: StatsReceiver, decider: Decider, logger: Logger): ConfigBuilder = { + val deciderGateBuilder: DeciderGateBuilder = + new DeciderGateBuilder(decider) + + new ConfigBuilder( + deciderGateBuilder, + statsReceiver, + logger + ) 
+ }
+}
+
+class ConfigBuilder(
+  deciderGateBuilder: DeciderGateBuilder,
+  statsReceiver: StatsReceiver,
+  logger: Logger) {

+  // Memoized as a val so the per-SafetyLevel cache is shared across calls; as a
+  // def, every invocation would build a fresh memo table and defeat the caching
+  // relied on by VisibilityParams.memoized.
+  val buildMemoized: SafetyLevel => Config = Memoize(build)
+
+  def build(safetyLevel: SafetyLevel): Config = {
+    new CompositeConfig(
+      VisibilityExperimentsConfig.config(safetyLevel) :+
+        VisibilityDeciders.config(deciderGateBuilder, logger, statsReceiver, safetyLevel) :+
+        VisibilityFeatureSwitches.config(statsReceiver, logger)
+    )
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityParams.scala
new file mode 100644
index 000000000..e45485b52
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityParams.scala
@@ -0,0 +1,61 @@
+package com.twitter.visibility.configapi
+
+import com.twitter.abdecider.LoggingABDecider
+import com.twitter.decider.Decider
+import com.twitter.featureswitches.v2.FeatureSwitches
+import com.twitter.finagle.stats.StatsReceiver
+import com.twitter.logging.Logger
+import com.twitter.servo.util.MemoizingStatsReceiver
+import com.twitter.timelines.configapi.Params
+import com.twitter.visibility.models.SafetyLevel
+import com.twitter.visibility.models.UnitOfDiversion
+import com.twitter.visibility.models.ViewerContext
+
+object VisibilityParams {
+  def apply(
+    log: Logger,
+    statsReceiver: StatsReceiver,
+    decider: Decider,
+    abDecider: LoggingABDecider,
+    featureSwitches: FeatureSwitches
+  ): VisibilityParams =
+    new VisibilityParams(log, statsReceiver, decider, abDecider, featureSwitches)
+}
+
+class VisibilityParams(
+  log: Logger,
+  statsReceiver: StatsReceiver,
+  decider: Decider,
+  abDecider: LoggingABDecider,
+  featureSwitches: FeatureSwitches) {
+
+  private[this] val contextFactory = new VisibilityRequestContextFactory(
+    abDecider,
+    featureSwitches
+  )
+
+  private[this] val configBuilder = ConfigBuilder(statsReceiver.scope("config"), decider, log)
+
+  private[this] val paramStats: MemoizingStatsReceiver = new MemoizingStatsReceiver(
+    statsReceiver.scope("configapi_params"))
+
+  def apply(
+    viewerContext: ViewerContext,
+    safetyLevel: SafetyLevel,
+    unitsOfDiversion: Seq[UnitOfDiversion] = Seq.empty
+  ): Params = {
+    val config = configBuilder.build(safetyLevel)
+    val requestContext = contextFactory(viewerContext, safetyLevel, unitsOfDiversion)
+    config.apply(requestContext, paramStats)
+  }
+
+  def memoized(
+    viewerContext: ViewerContext,
+    safetyLevel: SafetyLevel,
+    unitsOfDiversion: Seq[UnitOfDiversion] = Seq.empty
+  ): Params = {
+    val config = configBuilder.buildMemoized(safetyLevel)
+    val requestContext = contextFactory(viewerContext, safetyLevel, unitsOfDiversion)
+    config.apply(requestContext, paramStats)
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContext.scala
new file mode 100644
index 000000000..9e9564392
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContext.scala
@@ -0,0 +1,14 @@
+package com.twitter.visibility.configapi
+
+import com.twitter.timelines.configapi._
+
+case class VisibilityRequestContext(
+  userId: Option[Long],
+  guestId: Option[Long],
+  experimentContext: ExperimentContext = NullExperimentContext,
+  featureContext: FeatureContext = NullFeatureContext)
+    extends BaseRequestContext
+    with WithUserId
+    with WithGuestId
+    with
WithExperimentContext + with WithFeatureContext diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContextFactory.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContextFactory.scala new file mode 100644 index 000000000..1d389d68e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/VisibilityRequestContextFactory.scala @@ -0,0 +1,64 @@ +package com.twitter.visibility.configapi + +import com.twitter.abdecider.LoggingABDecider +import com.twitter.featureswitches.FSRecipient +import com.twitter.featureswitches.v2.FeatureSwitches +import com.twitter.timelines.configapi.abdecider.UserRecipientExperimentContextFactory +import com.twitter.timelines.configapi.featureswitches.v2.FeatureSwitchResultsFeatureContext +import com.twitter.timelines.configapi.FeatureContext +import com.twitter.timelines.configapi.NullExperimentContext +import com.twitter.timelines.configapi.UseFeatureContextExperimentContext +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.UnitOfDiversion +import com.twitter.visibility.models.ViewerContext + +class VisibilityRequestContextFactory( + loggingABDecider: LoggingABDecider, + featureSwitches: FeatureSwitches) { + private val userExperimentContextFactory = new UserRecipientExperimentContextFactory( + loggingABDecider + ) + private[this] def getFeatureContext( + context: ViewerContext, + safetyLevel: SafetyLevel, + unitsOfDiversion: Seq[UnitOfDiversion] + ): FeatureContext = { + val uodCustomFields = unitsOfDiversion.map(_.apply) + val recipient = FSRecipient( + userId = context.userId, + guestId = context.guestId, + userAgent = context.fsUserAgent, + clientApplicationId = context.clientApplicationId, + countryCode = context.requestCountryCode, + deviceId = context.deviceId, + languageCode = context.requestLanguageCode, + isTwoffice = Some(context.isTwOffice), + userRoles = context.userRoles, + ).withCustomFields(("safety_level", safetyLevel.name), uodCustomFields: _*) + + val results = featureSwitches.matchRecipient(recipient) + new FeatureSwitchResultsFeatureContext(results) + } + + def apply( + context: ViewerContext, + safetyLevel: SafetyLevel, + unitsOfDiversion: Seq[UnitOfDiversion] = Seq.empty + ): VisibilityRequestContext = { + val experimentContextBase = + context.userId + .map(userId => userExperimentContextFactory.apply(userId)).getOrElse(NullExperimentContext) + + val featureContext = getFeatureContext(context, safetyLevel, unitsOfDiversion) + + val experimentContext = + UseFeatureContextExperimentContext(experimentContextBase, featureContext) + + VisibilityRequestContext( + context.userId, + context.guestId, + experimentContext, + featureContext + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/BUILD new file mode 100644 index 000000000..89e1b7c83 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/BUILD @@ -0,0 +1,17 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "configapi/configapi-core", + "configapi/configapi-decider", + "decider", + "finagle/finagle-stats", + "servo/decider", + "servo/decider/src/main/scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + 
"visibility/lib/src/main/scala/com/twitter/visibility/models", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/DeciderKey.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/DeciderKey.scala new file mode 100644 index 000000000..9fefb4154 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/DeciderKey.scala @@ -0,0 +1,1064 @@ +package com.twitter.visibility.configapi.configs + +import com.twitter.servo.decider.DeciderKeyEnum + +private[visibility] object DeciderKey extends DeciderKeyEnum { + + val EnableAllSubscribedListsSafetyLevel: Value = Value( + "visibility_library_enable_all_subscribed_lists_safety_level" + ) + val EnableAdsBusinessSettingsSafetyLevel: Value = Value( + "visibility_library_enable_ads_business_settings_safety_level" + ) + val EnableAdsCampaignSafetyLevel: Value = Value( + "visibility_library_enable_ads_campaign_safety_level" + ) + val EnableAdsManagerSafetyLevel: Value = Value( + "visibility_library_enable_ads_manager_safety_level" + ) + val EnableAdsReportingDashboardSafetyLevel: Value = Value( + "visibility_library_enable_ads_reporting_dashboard_safety_level" + ) + val EnableAppealsSafetyLevel: Value = Value( + "visibility_library_enable_appeals_safety_level" + ) + val EnableArticleTweetTimelineSafetyLevel: Value = Value( + "visibility_library_enable_article_tweet_timeline_safety_level" + ) + val EnableBaseQig: Value = Value( + "visibility_library_enable_base_qig_safety_level" + ) + val EnableBirdwatchNoteAuthorSafetyLevel: Value = Value( + "visibility_library_enable_birdwatch_note_author_safety_level" + ) + val EnableBirdwatchNoteTweetsTimelineSafetyLevel: Value = Value( + "visibility_library_enable_birdwatch_note_tweets_timeline_safety_level" + ) + val EnableBirdwatchNeedsYourHelpNotificationsSafetyLevel: Value = Value( + "visibility_library_enable_birdwatch_needs_your_help_notifications_safety_level" + ) + val EnableBlockMuteUsersTimelineSafetyLevel: Value = Value( + "visibility_library_enable_block_mute_users_timeline_safety_level" + ) + val EnableBrandSafetySafetyLevel: Value = Value( + "visibility_library_enable_brand_safety_safety_level" + ) + val EnableCardPollVotingSafetyLevel: Value = Value( + "visibility_library_enable_card_poll_voting_safety_level" + ) + val EnableCardsServiceSafetyLevel: Value = Value( + "visibility_library_enable_cards_service_safety_level" + ) + val EnableCommunitiesSafetyLevel: Value = Value( + "visibility_library_enable_communities_safety_level" + ) + val EnableContentControlToolInstallSafetyLevel: Value = Value( + "visibility_library_enable_content_control_tool_install_safety_level" + ) + val EnableConversationFocalPrehydrationSafetyLevel: Value = Value( + "visibility_library_enable_conversation_focal_prehydration_safety_level" + ) + val EnableConversationFocalTweetSafetyLevel: Value = Value( + "visibility_library_enable_conversation_focal_tweet_safety_level" + ) + val EnableConversationInjectedTweetSafetyLevel: Value = Value( + "visibility_library_enable_conversation_injected_tweet_safety_level" + ) + val EnableConversationReplySafetyLevel: Value = Value( + "visibility_library_enable_conversation_reply_safety_level" + ) + val EnableCuratedTrendsRepresentativeTweet: Value = Value( + "visibility_library_curated_trends_representative_tweet" + ) + val EnableCurationPolicyViolations: Value = Value( + "visibility_library_curation_policy_violations" + ) + val EnableDeprecatedSafetyLevelSafetyLevel: Value = Value( + 
"visibility_library_enable_deprecated_safety_level_safety_level" + ) + val EnableDevPlatformGetListTweetsSafetyLevel: Value = Value( + "visibility_library_enable_dev_platform_get_list_tweets_safety_level" + ) + val EnableDesFollowingAndFollowersUserListSafetyLevel: Value = Value( + "visibility_library_enable_des_following_and_followers_user_list_safety_level" + ) + val EnableDesHomeTimelineSafetyLevel: Value = Value( + "visibility_library_enable_des_home_timeline_safety_level" + ) + val EnableDesQuoteTweetTimelineSafetyLevel: Value = Value( + "visibility_library_enable_des_quote_tweet_timeline_safety_level" + ) + val EnableDesRealtimeSafetyLevel: Value = Value( + "visibility_library_enable_des_realtime_safety_level" + ) + val EnableDesRealtimeSpamEnrichmentSafetyLevel: Value = Value( + "visibility_library_enable_des_realtime_spam_enrichment_safety_level" + ) + val EnableDesRealtimeTweetFilterSafetyLevel: Value = Value( + "visibility_library_enable_des_realtime_tweet_filter_safety_level" + ) + val EnableDesRetweetingUsersSafetyLevel: Value = Value( + "visibility_library_enable_des_retweeting_users_safety_level" + ) + val EnableDesTweetDetailSafetyLevel: Value = Value( + "visibility_library_enable_des_tweet_detail_safety_level" + ) + val EnableDesTweetLikingUsersSafetyLevel: Value = Value( + "visibility_library_enable_des_tweet_liking_users_safety_level" + ) + val EnableDesUserBookmarksSafetyLevel: Value = Value( + "visibility_library_enable_des_user_bookmarks_safety_level" + ) + val EnableDesUserLikedTweetsSafetyLevel: Value = Value( + "visibility_library_enable_des_user_liked_tweets_safety_level" + ) + val EnableDesUserMentionsSafetyLevel: Value = Value( + "visibility_library_enable_des_user_mentions_safety_level" + ) + val EnableDesUserTweetsSafetyLevel: Value = Value( + "visibility_library_enable_des_user_tweets_safety_level" + ) + val EnableDevPlatformComplianceStreamSafetyLevel: Value = Value( + "visibility_library_enable_dev_platform_compliance_stream_safety_level" + ) + val EnableDirectMessagesSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_safety_level" + ) + val EnableDirectMessagesConversationListSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_conversation_list_safety_level" + ) + val EnableDirectMessagesConversationTimelineSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_conversation_timeline_safety_level" + ) + val EnableDirectMessagesInboxSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_inbox_safety_level" + ) + val EnableDirectMessagesMutedUsersSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_muted_users_safety_level" + ) + val EnableDirectMessagesPinnedSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_pinned_safety_level" + ) + val EnableDirectMessagesSearchSafetyLevel: Value = Value( + "visibility_library_enable_direct_messages_search_safety_level" + ) + val EnableElevatedQuoteTweetTimelineSafetyLevel: Value = Value( + "visibility_library_enable_elevated_quote_tweet_timeline_safety_level" + ) + val EnableEmbeddedTweetSafetyLevel: Value = Value( + "visibility_library_enable_embedded_tweet_safety_level" + ) + val EnableEmbedsPublicInterestNoticeSafetyLevel: Value = Value( + "visibility_library_enable_embeds_public_interest_notice_safety_level" + ) + val EnableEmbedTweetMarkupSafetyLevel: Value = Value( + "visibility_library_enable_embed_tweet_markup_safety_level" + ) + val EnableWritePathLimitedActionsEnforcementSafetyLevel: Value = 
Value( + "visibility_library_enable_write_path_limited_actions_enforcement_safety_level" + ) + val EnableFilterDefaultSafetyLevel: Value = Value( + "visibility_library_enable_filter_default_safety_level" + ) + val EnableFilterNoneSafetyLevel: Value = Value( + "visibility_library_enable_filter_none_safety_level" + ) + + val EnableFilterAllSafetyLevel: Value = Value( + "visibility_library_enable_filter_all_safety_level" + ) + val EnableFilterAllPlaceholderSafetyLevel: Value = Value( + "visibility_library_enable_filter_all_placeholder_safety_level" + ) + + val EnableFollowedTopicsTimelineSafetyLevel: Value = Value( + "visibility_library_enable_followed_topics_timeline_safety_level" + ) + + val EnableFollowerConnectionsSafetyLevel: Value = Value( + "visibility_library_enable_follower_connections_safety_level" + ) + val EnableFollowingAndFollowersUserListSafetyLevel: Value = Value( + "visibility_library_enable_following_and_followers_user_list_safety_level" + ) + + val EnableForDevelopmentOnlySafetyLevel: Value = Value( + "visibility_library_enable_for_development_only_safety_level" + ) + + val EnableFriendsFollowingListSafetyLevel: Value = Value( + "visibility_library_enable_friends_following_list_safety_level" + ) + + val EnableGraphqlDefaultSafetyLevel: Value = Value( + "visibility_library_enable_graphql_default_safety_level" + ) + + val EnableGryphonDecksAndColumnsSafetyLevel: Value = Value( + "visibility_library_enable_gryphon_decks_and_columns_safety_level" + ) + + val EnableHumanizationNudgeSafetyLevel: Value = Value( + "visibility_library_enable_humanization_nudge_safety_level" + ) + + val EnableKitchenSinkDevelopmentSafetyLevel: Value = Value( + "visibility_library_enable_kitchen_sink_development_safety_level" + ) + + val EnableListHeaderSafetyLevel: Value = Value( + "visibility_library_enable_list_header_safety_level" + ) + + val EnableListMembershipsSafetyLevel: Value = Value( + "visibility_library_enable_list_memberships_safety_level" + ) + + val EnableListOwnershipsSafetyLevel: Value = Value( + "visibility_library_enable_list_ownerships_safety_level" + ) + + val EnableListRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_list_recommendations_safety_level" + ) + + val EnableListSearchSafetyLevel: Value = Value( + "visibility_library_enable_list_search_safety_level" + ) + + val EnableListSubscriptionsSafetyLevel: Value = Value( + "visibility_library_enable_list_subscriptions_safety_level" + ) + + val EnableLivePipelineEngagementCountsSafetyLevel: Value = Value( + "visibility_library_enable_live_pipeline_engagement_counts_safety_level" + ) + val EnableLiveVideoTimelineSafetyLevel: Value = Value( + "visibility_library_enable_live_video_timeline_safety_level" + ) + val EnableMagicRecsAggressiveSafetyLevel: Value = Value( + "visibility_library_enable_magic_recs_aggressive_safety_level" + ) + val EnableMagicRecsAggressiveV2SafetyLevel: Value = Value( + "visibility_library_enable_magic_recs_aggressive_v2_safety_level" + ) + val EnableMagicRecsSafetyLevel: Value = Value( + "visibility_library_enable_magic_recs_safety_level" + ) + val EnableMagicRecsV2SafetyLevel: Value = Value( + "visibility_library_enable_magic_recs_v2_safety_level" + ) + val EnableMinimalSafetyLevel: Value = Value( + "visibility_library_enable_minimal_safety_level" + ) + val EnableModeratedTweetsTimelineSafetyLevel: Value = Value( + "visibility_library_enable_moderated_tweets_timeline_safety_level" + ) + val EnableMomentsSafetyLevel: Value = Value( + "visibility_library_enable_moments_safety_level" + 
) + val EnableNearbyTimelineSafetyLevel: Value = Value( + "visibility_library_enable_nearby_timeline_safety_level" + ) + val EnableNewUserExperienceSafetyLevel: Value = Value( + "visibility_library_enable_new_user_experience_safety_level" + ) + val EnableNotificationsIbisSafetyLevel: Value = Value( + "visibility_library_enable_notifications_ibis_safety_level" + ) + val EnableNotificationsPlatformSafetyLevel: Value = Value( + "visibility_library_enable_notifications_platform_safety_level" + ) + val EnableNotificationsPlatformPushSafetyLevel: Value = Value( + "visibility_library_enable_notifications_platform_push_safety_level" + ) + val EnableNotificationsReadSafetyLevel: Value = Value( + "visibility_library_enable_notifications_read_safety_level" + ) + val EnableNotificationsTimelineDeviceFollowSafetyLevel: Value = Value( + "visibility_library_enable_notifications_timeline_device_follow_safety_level" + ) + val EnableNotificationsWriteSafetyLevel: Value = Value( + "visibility_library_enable_notifications_write_safety_level" + ) + val EnableNotificationsWriterV2SafetyLevel: Value = Value( + "visibility_library_enable_notification_writer_v2_safety_level" + ) + val EnableNotificationsWriterTweetHydratorSafetyLevel: Value = Value( + "visibility_library_enable_notifications_writer_tweet_hydrator_safety_level" + ) + val EnableQuickPromoteTweetEligibilitySafetyLevel: Value = Value( + "visibility_library_enable_quick_promote_tweet_eligibility_safety_level" + ) + val EnableQuoteTweetTimelineSafetyLevel: Value = Value( + "visibility_library_enable_quote_tweet_timeline_safety_level" + ) + val EnableQuotedTweetRulesSafetyLevel: Value = Value( + "visibility_library_enable_quoted_tweet_rules_safety_level" + ) + val EnableRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_recommendations_safety_level" + ) + val EnableRecosVideoSafetyLevel: Value = Value( + "visibility_library_enable_recos_video_safety_level" + ) + val EnableRecosWritePathSafetyLevel: Value = Value( + "visibility_library_enable_recos_write_path_safety_level" + ) + val EnableRepliesGroupingSafetyLevel: Value = Value( + "visibility_library_enable_replies_grouping_safety_level" + ) + val EnableReportCenterSafetyLevel: Value = Value( + "visibility_library_enable_report_center_safety_level" + ) + val EnableReturningUserExperienceSafetyLevel: Value = Value( + "visibility_library_enable_returning_user_experience_safety_level" + ) + val EnableReturningUserExperienceFocalTweetSafetyLevel: Value = Value( + "visibility_library_enable_returning_user_experience_focal_tweet_safety_level" + ) + val EnableRevenueSafetyLevel: Value = Value( + "visibility_library_enable_revenue_safety_level" + ) + val EnableRitoActionedTweetTimelineSafetyLevel: Value = Value( + "visibility_library_enable_rito_actioned_tweet_timeline_safety_level" + ) + val EnableSafeSearchMinimalSafetyLevel: Value = Value( + "visibility_library_enable_safe_search_minimal_safety_level" + ) + val EnableSafeSearchStrictSafetyLevel: Value = Value( + "visibility_library_enable_safe_search_strict_safety_level" + ) + val EnableAccessInternalPromotedContentSafetyLevel: Value = Value( + "visibility_library_enable_access_internal_promoted_content_safety_level" + ) + val EnableSearchHydration: Value = Value( + "visibility_library_enable_search_hydration_safety_level" + ) + val EnableSearchLatest: Value = Value( + "visibility_library_enable_search_latest_safety_level" + ) + val EnableSearchMixerSrpMinimalSafetyLevel: Value = Value( + 
"visibility_library_enable_search_mixer_srp_minimal_safety_level" + ) + val EnableSearchMixerSrpStrictSafetyLevel: Value = Value( + "visibility_library_enable_search_mixer_srp_strict_safety_level" + ) + val EnableUserSearchSrpSafetyLevel: Value = Value( + "visibility_library_enable_user_search_srp_safety_level" + ) + val EnableUserSearchTypeaheadSafetyLevel: Value = Value( + "visibility_library_enable_user_search_typeahead_safety_level" + ) + val EnableSearchPeopleSrp: Value = Value( + "visibility_library_enable_search_people_srp_safety_level" + ) + val EnableSearchPeopleTypeahead: Value = Value( + "visibility_library_enable_search_people_typeahead_safety_level" + ) + val EnableSearchPhoto: Value = Value( + "visibility_library_enable_search_photo_safety_level" + ) + val EnableSearchTop: Value = Value( + "visibility_library_enable_search_top_safety_level" + ) + val EnableSearchTopQig: Value = Value( + "visibility_library_enable_search_top_qig_safety_level" + ) + val EnableSearchTrendTakeoverPromotedTweet: Value = Value( + "visibility_library_enable_search_trend_takeover_promoted_tweet_safety_level" + ) + val EnableSearchVideo: Value = Value( + "visibility_library_enable_search_video_safety_level" + ) + val EnableSearchLatestUserRules: Value = Value( + "visibility_library_enable_search_latest_user_rules_safety_level" + ) + val EnableShoppingManagerSpyModeSafetyLevel: Value = Value( + "visibility_library_enable_shopping_manager_spy_mode_safety_level" + ) + val EnableSignalsReactions: Value = Value( + "visibility_library_enable_signals_reactions_safety_level" + ) + val EnableSignalsTweetReactingUsers: Value = Value( + "visibility_library_enable_signals_tweet_reacting_users_safety_level" + ) + val EnableSocialProof: Value = Value( + "visibility_library_enable_social_proof_safety_level" + ) + val EnableSoftInterventionPivot: Value = Value( + "visibility_library_enable_soft_intervention_pivot_safety_level" + ) + val EnableSpaceFleetlineSafetyLevel: Value = Value( + "visibility_library_enable_space_fleetline_safety_level" + ) + val EnableSpaceHomeTimelineUprankingSafetyLevel: Value = Value( + "visibility_library_enable_space_home_timeline_upranking_safety_level" + ) + val EnableSpaceJoinScreenSafetyLevel: Value = Value( + "visibility_library_enable_space_join_screen_safety_level" + ) + val EnableSpaceNotificationsSafetyLevel: Value = Value( + "visibility_library_enable_space_notifications_safety_level" + ) + val EnableSpacesSafetyLevel: Value = Value( + "visibility_library_enable_spaces_safety_level" + ) + val EnableSpacesParticipantsSafetyLevel: Value = Value( + "visibility_library_enable_spaces_participants_safety_level" + ) + val EnableSpacesSellerApplicationStatus: Value = Value( + "visibility_library_enable_spaces_seller_application_status_safety_level" + ) + val EnableSpacesSharingSafetyLevel: Value = Value( + "visibility_library_enable_spaces_sharing_safety_level" + ) + val EnableSpaceTweetAvatarHomeTimelineSafetyLevel: Value = Value( + "visibility_library_enable_space_tweet_avatar_home_timeline_safety_level" + ) + val EnableStickersTimelineSafetyLevel: Value = Value( + "visibility_library_enable_stickers_timeline_safety_level" + ) + val EnableStratoExtLimitedEngagementsSafetyLevel: Value = Value( + "visibility_library_enable_strato_ext_limited_engagements_safety_level" + ) + val EnableStreamServicesSafetyLevel: Value = Value( + "visibility_library_enable_stream_services_safety_level" + ) + val EnableTestSafetyLevel: Value = Value( + "visibility_library_enable_test_safety_level" + ) + val 
EnableTimelineBookmarkSafetyLevel: Value = Value( + "visibility_library_enable_timeline_bookmark_safety_level" + ) + val EnableTimelineContentControlsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_content_controls_safety_level" + ) + val EnableTimelineConversationsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_conversations_safety_level" + ) + val EnableTimelineConversationsDownrankingSafetyLevel: Value = Value( + "visibility_library_enable_timeline_conversations_downranking_safety_level" + ) + val EnableTimelineConversationsDownrankingMinimalSafetyLevel: Value = Value( + "visibility_library_enable_timeline_conversations_downranking_minimal_safety_level" + ) + val EnableTimelineFavoritesSafetyLevel: Value = Value( + "visibility_library_enable_timeline_favorites_safety_level" + ) + val EnableSelfViewTimelineFavoritesSafetyLevel: Value = Value( + "visibility_library_enable_self_view_timeline_favorites_safety_level" + ) + val EnableTweetTimelineFocalTweetSafetyLevel: Value = Value( + "visibility_library_enable_timeline_focal_tweet_safety_level" + ) + val EnableTweetScopedTimelineSafetyLevel: Value = Value( + "visibility_library_enable_tweet_scoped_timeline_safety_level" + ) + val EnableTimelineFollowingActivitySafetyLevel: Value = Value( + "visibility_library_enable_timeline_following_activity_safety_level" + ) + val EnableTimelineHomeSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_safety_level" + ) + val EnableTimelineHomeCommunitiesSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_communities_safety_level" + ) + val EnableTimelineHomeHydrationSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_hydration_safety_level" + ) + val EnableTimelineHomeLatestSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_latest_safety_level" + ) + val EnableTimelineHomeRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_recommendations_safety_level" + ) + val EnableTimelineHomeTopicFollowRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_topic_follow_recommendations_safety_level" + ) + val EnableTimelineScorerSafetyLevel: Value = Value( + "visibility_library_enable_timeline_scorer_safety_level" + ) + val EnableTopicsLandingPageTopicRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_topics_landing_page_topic_recommendations_safety_level" + ) + val EnableExploreRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_explore_recommendations_safety_level" + ) + val EnableTimelineInjectionSafetyLevel: Value = Value( + "visibility_library_enable_timeline_injection_safety_level" + ) + val EnableTimelineLikedBySafetyLevel: Value = Value( + "visibility_library_enable_timeline_liked_by_safety_level" + ) + val EnableTimelineListsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_lists_safety_level" + ) + val EnableTimelineMediaSafetyLevel: Value = Value( + "visibility_library_enable_timeline_media_safety_level" + ) + val EnableTimelineMentionsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_mentions_safety_level" + ) + val EnableTimelineModeratedTweetsHydrationSafetyLevel: Value = Value( + "visibility_library_enable_timeline_moderated_tweets_hydration_safety_level" + ) + val EnableTimelineProfileSafetyLevel: Value = Value( + "visibility_library_enable_timeline_profile_safety_level" + ) + val EnableTimelineProfileAllSafetyLevel: Value = Value( + 
"visibility_library_enable_timeline_profile_all_safety_level" + ) + val EnableTimelineProfileSpacesSafetyLevel: Value = Value( + "visibility_library_enable_timeline_profile_spaces_safety_level") + val EnableTimelineProfileSuperFollowsSafetyLevel: Value = Value( + "visibility_library_enable_timeline_profile_super_follows_safety_level" + ) + val EnableTimelineReactiveBlendingSafetyLevel: Value = Value( + "visibility_library_enable_timeline_reactive_blending_safety_level" + ) + val EnableTimelineRetweetedBySafetyLevel: Value = Value( + "visibility_library_enable_timeline_retweeted_by_safety_level" + ) + val EnableTimelineSuperLikedBySafetyLevel: Value = Value( + "visibility_library_enable_timeline_super_liked_by_safety_level" + ) + val EnableTombstoningSafetyLevel: Value = Value( + "visibility_library_enable_tombstoning_safety_level" + ) + val EnableTopicRecommendationsSafetyLevel: Value = Value( + "visibility_library_enable_topic_recommendations_safety_level" + ) + val EnableTrendsRepresentativeTweetSafetyLevel: Value = Value( + "visibility_library_enable_trends_representative_tweet_safety_level" + ) + val EnableTrustedFriendsUserListSafetyLevel: Value = Value( + "visibility_library_enable_trusted_friends_user_list_safety_level" + ) + val EnableTweetDetailSafetyLevel: Value = Value( + "visibility_library_enable_tweet_detail_safety_level" + ) + val EnableTweetDetailNonTooSafetyLevel: Value = Value( + "visibility_library_enable_tweet_detail_non_too_safety_level" + ) + val EnableTweetDetailWithInjectionsHydrationSafetyLevel: Value = Value( + "visibility_library_enable_tweet_detail_with_injections_hydration_safety_level" + ) + val EnableTweetEngagersSafetyLevel: Value = Value( + "visibility_library_enable_tweet_engagers_safety_level" + ) + val EnableTweetReplyNudgeSafetyLevel: Value = Value( + "visibility_library_enable_tweet_reply_nudge_safety_level" + ) + val EnableTweetWritesApiSafetyLevel: Value = Value( + "visibility_library_enable_tweet_writes_api_safety_level" + ) + val EnableTwitterArticleComposeSafetyLevel: Value = Value( + "visibility_library_enable_twitter_article_compose_safety_level" + ) + val EnableTwitterArticleProfileTabSafetyLevel: Value = Value( + "visibility_library_enable_twitter_article_profile_tab_safety_level" + ) + val EnableTwitterArticleReadSafetyLevel: Value = Value( + "visibility_library_enable_twitter_article_read_safety_level" + ) + val EnableUserProfileHeaderSafetyLevel: Value = Value( + "visibility_library_enable_user_profile_header_safety_level" + ) + val EnableUserMilestoneRecommendationSafetyLevel: Value = Value( + "visibility_library_enable_user_milestone_recommendation_safety_level" + ) + val EnableUserScopedTimelineSafetyLevel: Value = Value( + "visibility_library_enable_user_scoped_timeline_safety_level" + ) + val EnableUserSettingsSafetyLevel: Value = Value( + "visibility_library_enable_user_settings_safety_level" + ) + val EnableVideoAdsSafetyLevel: Value = Value( + "visibility_library_enable_video_ads_safety_level" + ) + val EnableTimelineHomePromotedHydrationSafetyLevel: Value = Value( + "visibility_library_enable_timeline_home_promoted_hydration_safety_level" + ) + + val EnableSuperFollowerConnectionsSafetyLevel: Value = Value( + "visibility_library_enable_super_follower_connnections_safety_level" + ) + + val EnableSuperLikeSafetyLevel: Value = Value( + "visibility_library_enable_super_like_safety_level" + ) + + val EnableZipbirdConsumerArchivesSafetyLevel: Value = Value( + "visibility_library_enable_zipbird_consumer_archives_safety_level" + ) + + val 
EnableTweetAwardSafetyLevel: Value = Value( + "visibility_library_enable_tweet_award_safety_level" + ) + + val EnableTweetConversationControlRules: Value = Value( + "visibility_library_enable_conversation_control_rules" + ) + val EnableCommunityTweetsControlRules: Value = Value( + "visibility_library_enable_community_tweets_rules" + ) + val EnableDropCommunityTweetWithUndefinedCommunityRule: Value = Value( + "visibility_library_enable_drop_community_tweet_with_undefined_community_rule" + ) + val EnablePSpammyTweetDownrankConvosLowQuality: Value = Value( + "visibility_library_enable_p_spammy_tweet_downrank_convos_low_quality" + ) + val EnableHighPSpammyTweetScoreSearchTweetLabelDropRule: Value = Value( + "visibility_library_enable_high_p_spammy_tweet_score_search_tweet_label_drop_rule" + ) + + val EnableRitoActionedTweetDownrankConvosLowQuality: Value = Value( + "visibility_library_enable_rito_actioned_tweet_downrank_convos_low_quality" + ) + + val EnableNewSensitiveMediaSettingsInterstitialRulesHomeTimeline: Value = Value( + "visibility_library_enable_new_sensitive_media_settings_interstitial_rules_home_timeline") + + val EnableLegacySensitiveMediaRulesHomeTimeline: Value = Value( + "visibility_library_enable_legacy_sensitive_media_rules_home_timeline" + ) + + val EnableNewSensitiveMediaSettingsInterstitialRulesConversation: Value = Value( + "visibility_library_enable_new_sensitive_media_settings_interstitial_rules_conversation" + ) + + val EnableLegacySensitiveMediaRulesConversation: Value = Value( + "visibility_library_enable_legacy_sensitive_media_rules_conversation" + ) + + val EnableNewSensitiveMediaSettingsInterstitialRulesProfileTimeline: Value = Value( + "visibility_library_enable_new_sensitive_media_settings_interstitials_rules_profile_timeline" + ) + + val EnableLegacySensitiveMediaRulesProfileTimeline: Value = Value( + "visibility_library_enable_legacy_sensitive_media_rules_profile_timeline" + ) + + val EnableNewSensitiveMediaSettingsInterstitialRulesTweetDetail: Value = Value( + "visibility_library_enable_new_sensitive_media_settings_interstitials_rules_tweet_detail" + ) + + val EnableLegacySensitiveMediaRulesTweetDetail: Value = Value( + "visibility_library_enable_legacy_sensitive_media_rules_tweet_detail" + ) + + val EnableLegacySensitiveMediaRulesDirectMessages: Value = Value( + "visibility_library_enable_legacy_sensitive_media_rules_direct_messages" + ) + + val VisibilityLibraryEnableToxicReplyFilterConversation: Value = Value( + "visibility_library_enable_toxic_reply_filter_conversation" + ) + + val VisibilityLibraryEnableToxicReplyFilterNotifications: Value = Value( + "visibility_library_enable_toxic_reply_filter_notifications" + ) + + val EnableSmyteSpamTweetRule: Value = Value( + "visibility_library_enable_smyte_spam_tweet_rule" + ) + + val EnableHighSpammyTweetContentScoreSearchLatestTweetLabelDropRule: Value = Value( + "visibility_library_enable_high_spammy_tweet_content_score_search_latest_tweet_label_drop_rule" + ) + val EnableHighSpammyTweetContentScoreSearchTopTweetLabelDropRule: Value = Value( + "visibility_library_enable_high_spammy_tweet_content_score_search_top_tweet_label_drop_rule" + ) + val EnableHighSpammyTweetContentScoreConvoDownrankAbusiveQualityRule: Value = Value( + "visibility_library_enable_high_spammy_tweet_content_score_convo_downrank_abusive_quality_rule" + ) + + val EnableHighCryptospamScoreConvoDownrankAbusiveQualityRule: Value = Value( + "visibility_library_enable_high_cryptospam_score_convo_downrank_abusive_quality_rule" + ) + val 
EnableHighSpammyTweetContentScoreTrendsTopTweetLabelDropRule: Value = Value( + "visibility_library_enable_high_spammy_tweet_content_score_trends_top_tweet_label_drop_rule" + ) + + val EnableHighSpammyTweetContentScoreTrendsLatestTweetLabelDropRule: Value = Value( + "visibility_library_enable_high_spammy_tweet_content_score_trends_latest_tweet_label_drop_rule" + ) + + val EnableGoreAndViolenceTopicHighRecallTweetLabelRule: Value = Value( + "visibility_library_enable_gore_and_violence_topic_high_recall_tweet_label_rule" + ) + + val EnableLimitRepliesFollowersConversationRule: Value = Value( + "visibility_library_enable_limit_replies_followers_conversation_rule" + ) + + val EnableBlinkBadDownrankingRule: Value = Value( + "visibility_library_enable_blink_bad_downranking_rule" + ) + + val EnableBlinkWorstDownrankingRule: Value = Value( + "visibility_library_enable_blink_worst_downranking_rule" + ) + + val EnableCopypastaSpamDownrankConvosAbusiveQualityRule: Value = Value( + "visibility_library_enable_copypasta_spam_downrank_convos_abusive_quality_rule" + ) + + val EnableCopypastaSpamSearchDropRule: Value = Value( + "visibility_library_enable_copypasta_spam_search_drop_rule" + ) + + val EnableSpammyUserModelHighPrecisionDropTweetRule: Value = Value( + "visibility_library_enable_spammy_user_model_high_precision_drop_tweet_rule" + ) + + val EnableAvoidNsfwRules: Value = Value( + "visibility_library_enable_avoid_nsfw_rules" + ) + + val EnableReportedTweetInterstitialRule: Value = Value( + "visibility_library_enable_reported_tweet_interstitial_rule" + ) + + val EnableReportedTweetInterstitialSearchRule: Value = Value( + "visibility_library_enable_reported_tweet_interstitial_search_rule" + ) + + val EnableDropExclusiveTweetContentRule: Value = Value( + "visibility_library_enable_drop_exclusive_tweet_content_rule" + ) + + val EnableDropExclusiveTweetContentRuleFailClosed: Value = Value( + "visibility_library_enable_drop_exclusive_tweet_content_rule_fail_closed" + ) + + val EnableTombstoneExclusiveQtProfileTimelineParam: Value = Value( + "visibility_library_enable_tombstone_exclusive_quoted_tweet_content_rule" + ) + + val EnableDropAllExclusiveTweetsRule: Value = Value( + "visibility_library_enable_drop_all_exclusive_tweets_rule" + ) + + val EnableDropAllExclusiveTweetsRuleFailClosed: Value = Value( + "visibility_library_enable_drop_all_exclusive_tweets_rule_fail_closed" + ) + + val EnableDownrankSpamReplySectioningRule: Value = Value( + "visibility_library_enable_downrank_spam_reply_sectioning_rule" + ) + + val EnableNsfwTextSectioningRule: Value = Value( + "visibility_library_enable_nsfw_text_sectioning_rule" + ) + + val EnableNsfwAgeBasedMediaDropRules: Value = Value( + "visibility_library_enable_nsfw_age_based_media_drop_rules" + ) + + val EnableNsfwUnderageRules: Value = Value( + "visibility_library_enable_nsfw_underage_rules" + ) + + val EnableNsfwUnderageMediaRules: Value = Value( + "visibility_library_enable_nsfw_underage_media_rules" + ) + + val EnableNsfwMissingAgeRules: Value = Value( + "visibility_library_enable_nsfw_missing_age_rules" + ) + + val EnableNsfwMissingAgeMediaRules: Value = Value( + "visibility_library_enable_nsfw_missing_age_media_rules" + ) + + val EnableSearchIpiSafeSearchWithoutUserInQueryDropRule: Value = Value( + "visibility_library_enable_search_ipi_safe_search_without_user_in_query_drop_rule" + ) + + val EnableTimelineHomePromotedTweetHealthEnforcementRules: Value = Value( + "visibility_library_enable_timeline_home_promoted_tweet_health_enforcement_rules" + ) + + val 
EnableMutedKeywordFilteringSpaceTitleNotificationsRule: Value = Value( + "visibility_library_enable_muted_keyword_filtering_space_title_notifications_rule" + ) + + val EnableDropTweetsWithGeoRestrictedMediaRule: Value = Value( + "visibility_library_enable_drop_tweets_with_georestricted_media_rule" + ) + val EnableDropAllTrustedFriendsTweetsRule: Value = Value( + "visibility_library_enable_drop_all_trusted_friends_tweets_rule" + ) + + val EnableDropTrustedFriendsTweetContentRule: Value = Value( + "visibility_library_enable_drop_all_trusted_friends_tweet_content_rule" + ) + + val EnableDropCollabInvitationTweetsRule: Value = Value( + "visibility_library_enable_drop_all_collab_invitation_tweets_rule" + ) + + val EnableFetchTweetReportedPerspective: Value = Value( + "visibility_library_enable_fetch_tweet_reported_perspective" + ) + + val EnableFetchTweetMediaMetadata: Value = Value( + "visibility_library_enable_fetch_tweet_media_metadata" + ) + + val VisibilityLibraryEnableFollowCheckInMutedKeyword: Value = Value( + "visibility_library_enable_follow_check_in_mutedkeyword" + ) + + val VisibilityLibraryEnableMediaInterstitialComposition: Value = Value( + "visibility_library_enable_media_interstitial_composition" + ) + + val EnableVerdictScribingFromTweetVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_scribing_from_tweet_visibility_library" + ) + + val EnableVerdictLoggerEventPublisherInstantiationFromTweetVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_logger_event_publisher_instantiation_from_tweet_visibility_library" + ) + + val EnableVerdictScribingFromTimelineConversationsVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_scribing_from_timeline_conversations_visibility_library" + ) + + val EnableVerdictLoggerEventPublisherInstantiationFromTimelineConversationsVisibilityLibrary: Value = + Value( + "visibility_library_enable_verdict_logger_event_publisher_instantiation_from_timeline_conversations_visibility_library" + ) + + val EnableVerdictScribingFromBlenderVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_scribing_from_blender_visibility_library" + ) + + val EnableVerdictScribingFromSearchVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_scribing_from_search_visibility_library" + ) + + val EnableVerdictLoggerEventPublisherInstantiationFromBlenderVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_logger_event_publisher_instantiation_from_blender_visibility_library" + ) + + val EnableVerdictLoggerEventPublisherInstantiationFromSearchVisibilityLibrary: Value = Value( + "visibility_library_enable_verdict_logger_event_publisher_instantiation_from_search_visibility_library" + ) + + val EnableLocalizedTombstoneOnVisibilityResults: Value = Value( + "visibility_library_enable_localized_tombstones_on_visibility_results" + ) + + val EnableShortCircuitingFromTweetVisibilityLibrary: Value = Value( + "visibility_library_enable_short_circuiting_from_tweet_visibility_library" + ) + + val EnableShortCircuitingFromTimelineConversationsVisibilityLibrary: Value = Value( + "visibility_library_enable_short_circuiting_from_timeline_conversations_visibility_library" + ) + + val EnableShortCircuitingFromBlenderVisibilityLibrary: Value = Value( + "visibility_library_enable_short_circuiting_from_blender_visibility_library" + ) + + val EnableShortCircuitingFromSearchVisibilityLibrary: Value = Value( + "visibility_library_enable_short_circuiting_from_search_visibility_library" + ) + + val 
EnableNsfwTextTopicsDropRule: Value = Value(
+    "visibility_library_enable_nsfw_text_topics_drop_rule"
+  )
+
+  val EnableSpammyTweetRuleVerdictLogging: Value = Value(
+    "visibility_library_enable_spammy_tweet_rule_verdict_logging"
+  )
+
+  val EnableDownlevelRuleVerdictLogging: Value = Value(
+    "visibility_library_enable_downlevel_rule_verdict_logging"
+  )
+
+  val EnableLikelyIvsUserLabelDropRule: Value = Value(
+    "visibility_library_enable_likely_ivs_user_label_drop_rule"
+  )
+
+  val EnableCardVisibilityLibraryCardUriParsing: Value = Value(
+    "visibility_library_enable_card_visibility_library_card_uri_parsing"
+  )
+
+  val EnableCardUriRootDomainDenylistRule: Value = Value(
+    "visibility_library_enable_card_uri_root_domain_deny_list_rule"
+  )
+
+  val EnableCommunityNonMemberPollCardRule: Value = Value(
+    "visibility_library_enable_community_non_member_poll_card_rule"
+  )
+
+  val EnableCommunityNonMemberPollCardRuleFailClosed: Value = Value(
+    "visibility_library_enable_community_non_member_poll_card_rule_fail_closed"
+  )
+
+  val EnableExperimentalNudgeLabelRule: Value = Value(
+    "visibility_library_enable_experimental_nudge_label_rule"
+  )
+
+  val NsfwHighPrecisionUserLabelAvoidTweetRuleEnabledParam: Value = Value(
+    "visibility_library_nsfw_high_precision_user_label_avoid_tweet_rule_enabled"
+  )
+
+  val EnableUserSelfViewOnlySafetyLevel: Value = Value(
+    "visibility_library_enable_user_self_view_only_safety_level"
+  )
+
+  val EnableNewAdAvoidanceRules: Value = Value(
+    "visibility_library_enable_new_ad_avoidance_rules"
+  )
+
+  val EnableNsfaHighRecallAdAvoidanceParam: Value = Value(
+    "visibility_library_enable_nsfa_high_recall_ad_avoidance_rules"
+  )
+
+  val EnableNsfaKeywordsHighPrecisionAdAvoidanceParam: Value = Value(
+    "visibility_library_enable_nsfa_keywords_high_precision_ad_avoidance_rules"
+  )
+
+  val EnableStaleTweetDropRuleParam: Value = Value(
+    "visibility_library_enable_stale_tweet_drop_rule"
+  )
+
+  val EnableStaleTweetDropRuleFailClosedParam: Value = Value(
+    "visibility_library_enable_stale_tweet_drop_rule_fail_closed"
+  )
+
+  val EnableEditHistoryTimelineSafetyLevel: Value = Value(
+    "visibility_library_enable_edit_history_timeline_safety_level"
+  )
+
+  val EnableDeleteStateTweetRules: Value = Value(
+    "visibility_library_enable_delete_state_tweet_rules"
+  )
+
+  val EnableSpacesSharingNsfwDropRulesParam: Value = Value(
+    "visibility_library_enable_spaces_sharing_nsfw_drop_rule"
+  )
+
+  val EnableDropMediaLegalRulesParam: Value = Value(
+    "visibility_library_enable_drop_media_legal_rules"
+  )
+
+  val EnableTombstoneMediaLegalRulesParam: Value = Value(
+    "visibility_library_enable_tombstone_media_legal_rules"
+  )
+
+  val EnableInterstitialMediaLegalRulesParam: Value = Value(
+    "visibility_library_enable_interstitial_media_legal_rules"
+  )
+
+  val EnableViewerIsSoftUserDropRuleParam: Value = Value(
+    "visibility_library_enable_viewer_is_soft_user_drop_rule"
+  )
+
+  val EnableBackendLimitedActions: Value = Value(
+    "visibility_library_enable_backend_limited_actions"
+  )
+
+  val EnableNotificationsQig: Value = Value(
+    "visibility_library_enable_notifications_qig_safety_level"
+  )
+
+  val EnablePdnaQuotedTweetTombstoneRule: Value = Value(
+    "visibility_library_enable_pdna_quoted_tweet_tombstone_rule"
+  )
+
+  val EnableSpamQuotedTweetTombstoneRule: Value = Value(
+    "visibility_library_enable_spam_quoted_tweet_tombstone_rule"
+  )
+
+  val EnableNsfwHpQuotedTweetDropRule: Value = Value(
+    "visibility_library_enable_nsfw_hp_quoted_tweet_drop_experiment_rule"
+  )
+  val EnableNsfwHpQuotedTweetTombstoneRule: Value = Value(
+    "visibility_library_enable_nsfw_hp_quoted_tweet_tombstone_experiment_rule"
+  )
+
+  val EnableExperimentalRuleEngine: Value = Value(
+    "visibility_library_enable_experimental_rule_engine"
+  )
+
+  val EnableFosnrRules: Value = Value(
+    "visibility_library_enable_fosnr_rules"
+  )
+
+  val EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRule: Value = Value(
+    "visibility_library_enable_inner_quoted_tweet_viewer_blocks_author_interstitial_rule"
+  )
+
+  val EnableInnerQuotedTweetViewerMutesAuthorInterstitialRule: Value = Value(
+    "visibility_library_enable_inner_quoted_tweet_viewer_mutes_author_interstitial_rule"
+  )
+
+  val EnableLocalizedInterstitialGenerator: Value = Value(
+    "visibility_library_enable_localized_interstitial_generator"
+  )
+
+  val EnableProfileMixerMediaSafetyLevel: Value = Value(
+    "visibility_library_enable_profile_mixer_media_safety_level")
+
+  val EnableConvosLocalizedInterstitial: Value = Value(
+    "visibility_library_convos_enable_localized_interstitial"
+  )
+
+  val EnableConvosLegacyInterstitial: Value = Value(
+    "visibility_library_convos_enable_legacy_interstitial"
+  )
+
+  val DisableLegacyInterstitialFilteredReason: Value = Value(
+    "visibility_library_disable_legacy_interstitial_filtered_reason"
+  )
+
+  val EnableSearchBasicBlockMuteRules: Value = Value(
+    "visibility_library_enable_search_basic_block_mute_rules"
+  )
+
+  val EnableLocalizedInterstitialInUserStateLib: Value = Value(
+    "visibility_library_enable_localized_interstitial_user_state_lib"
+  )
+
+  val EnableProfileMixerFavoritesSafetyLevel: Value = Value(
+    "visibility_library_enable_profile_mixer_favorites_safety_level")
+
+  val EnableAbusiveBehaviorDropRule: Value = Value(
+    "visibility_library_enable_abusive_behavior_drop_rule"
+  )
+
+  val EnableAbusiveBehaviorInterstitialRule: Value = Value(
+    "visibility_library_enable_abusive_behavior_interstitial_rule"
+  )
+
+  val EnableAbusiveBehaviorLimitedEngagementsRule: Value = Value(
+    "visibility_library_enable_abusive_behavior_limited_engagements_rule"
+  )
+
+  val EnableNotGraduatedDownrankConvosAbusiveQualityRule: Value = Value(
+    "visibility_library_enable_not_graduated_downrank_convos_abusive_quality_rule"
+  )
+
+  val EnableNotGraduatedSearchDropRule: Value = Value(
+    "visibility_library_enable_not_graduated_search_drop_rule"
+  )
+
+  val EnableNotGraduatedDropRule: Value = Value(
+    "visibility_library_enable_not_graduated_drop_rule"
+  )
+
+  val EnableMemoizeSafetyLevelParams: Value = Value(
+    "visibility_library_enable_memoize_safety_level_params"
+  )
+
+  val EnableAuthorBlocksViewerDropRule: Value = Value(
+    "visibility_library_enable_author_blocks_viewer_drop_rule"
+  )
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/ExperimentsHelper.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/ExperimentsHelper.scala
new file mode 100644
index 000000000..f2240311b
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/ExperimentsHelper.scala
@@ -0,0 +1,26 @@
+package com.twitter.visibility.configapi.configs
+
+import com.twitter.timelines.configapi.Config
+import com.twitter.timelines.configapi.ExperimentConfigBuilder
+import com.twitter.timelines.configapi.Param
+import com.twitter.visibility.configapi.params.VisibilityExperiment
+import com.twitter.visibility.models.SafetyLevel
+
+object ExperimentsHelper {
+
+  def mkABExperimentConfig(experiment: VisibilityExperiment, param: Param[Boolean]): Config = {
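+    // Control pins the param to true and treatment pins it to false, so the
+    // treatment bucket measures the effect of switching the gated rule off
+    // (holdback-style wiring).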
ExperimentConfigBuilder(experiment) + .addBucket( + experiment.ControlBucket, + param := true + ) + .addBucket( + experiment.TreatmentBucket, + param := false + ) + .build + } + + def mkABExperimentConfig(experiment: VisibilityExperiment, safetyLevel: SafetyLevel): Config = + mkABExperimentConfig(experiment, safetyLevel.enabledParam) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciderGates.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciderGates.scala new file mode 100644 index 000000000..7df596ce0 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciderGates.scala @@ -0,0 +1,73 @@ +package com.twitter.visibility.configapi.configs + +import com.twitter.decider.Decider +import com.twitter.servo.gate.DeciderGate +import com.twitter.servo.util.Gate + +case class VisibilityDeciderGates(decider: Decider) { + import DeciderKey._ + + private[this] def feature(deciderKey: Value) = decider.feature(deciderKey.toString) + + val enableFetchTweetReportedPerspective: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableFetchTweetReportedPerspective)) + val enableFetchTweetMediaMetadata: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableFetchTweetMediaMetadata)) + val enableFollowCheckInMutedKeyword: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.VisibilityLibraryEnableFollowCheckInMutedKeyword)) + val enableMediaInterstitialComposition: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.VisibilityLibraryEnableMediaInterstitialComposition)) + val enableExperimentalRuleEngine: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableExperimentalRuleEngine)) + + val enableLocalizedInterstitialGenerator: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableLocalizedInterstitialGenerator)) + + val enableShortCircuitingTVL: Gate[Unit] = + DeciderGate.linear(feature(EnableShortCircuitingFromTweetVisibilityLibrary)) + val enableVerdictLoggerTVL = DeciderGate.linear( + feature(DeciderKey.EnableVerdictLoggerEventPublisherInstantiationFromTweetVisibilityLibrary)) + val enableVerdictScribingTVL = + DeciderGate.linear(feature(DeciderKey.EnableVerdictScribingFromTweetVisibilityLibrary)) + val enableBackendLimitedActions = + DeciderGate.linear(feature(DeciderKey.EnableBackendLimitedActions)) + val enableMemoizeSafetyLevelParams: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableMemoizeSafetyLevelParams)) + + val enableShortCircuitingBVL: Gate[Unit] = + DeciderGate.linear(feature(EnableShortCircuitingFromBlenderVisibilityLibrary)) + val enableVerdictLoggerBVL = DeciderGate.linear( + feature(DeciderKey.EnableVerdictLoggerEventPublisherInstantiationFromBlenderVisibilityLibrary)) + val enableVerdictScribingBVL = + DeciderGate.linear(feature(DeciderKey.EnableVerdictScribingFromBlenderVisibilityLibrary)) + + val enableShortCircuitingSVL: Gate[Unit] = + DeciderGate.linear(feature(EnableShortCircuitingFromSearchVisibilityLibrary)) + val enableVerdictLoggerSVL = DeciderGate.linear( + feature(DeciderKey.EnableVerdictLoggerEventPublisherInstantiationFromSearchVisibilityLibrary)) + val enableVerdictScribingSVL = + DeciderGate.linear(feature(DeciderKey.EnableVerdictScribingFromSearchVisibilityLibrary)) + + val enableShortCircuitingTCVL: Gate[Unit] = + DeciderGate.linear(feature(EnableShortCircuitingFromTimelineConversationsVisibilityLibrary)) + val enableVerdictLoggerTCVL = DeciderGate.linear(feature( + 
DeciderKey.EnableVerdictLoggerEventPublisherInstantiationFromTimelineConversationsVisibilityLibrary)) + val enableVerdictScribingTCVL = + DeciderGate.linear( + feature(DeciderKey.EnableVerdictScribingFromTimelineConversationsVisibilityLibrary)) + + val enableCardVisibilityLibraryCardUriParsing = + DeciderGate.linear(feature(DeciderKey.EnableCardVisibilityLibraryCardUriParsing)) + + val enableConvosLocalizedInterstitial: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableConvosLocalizedInterstitial)) + + val enableConvosLegacyInterstitial: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableConvosLegacyInterstitial)) + + val disableLegacyInterstitialFilteredReason: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.DisableLegacyInterstitialFilteredReason)) + + val enableLocalizedInterstitialInUserStateLibrary: Gate[Unit] = + DeciderGate.linear(feature(DeciderKey.EnableLocalizedInterstitialInUserStateLib)) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciders.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciders.scala new file mode 100644 index 000000000..cc78fdb7e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityDeciders.scala @@ -0,0 +1,389 @@ +package com.twitter.visibility.configapi.configs + +import com.twitter.decider.Recipient +import com.twitter.decider.SimpleRecipient +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logging.Logger +import com.twitter.servo.decider.DeciderGateBuilder +import com.twitter.timelines.configapi.BaseConfigBuilder +import com.twitter.timelines.configapi.BaseRequestContext +import com.twitter.timelines.configapi.Config +import com.twitter.timelines.configapi.Param +import com.twitter.timelines.configapi.WithGuestId +import com.twitter.timelines.configapi.WithUserId +import com.twitter.timelines.configapi.decider.DeciderSwitchOverrideValue +import com.twitter.timelines.configapi.decider.GuestRecipient +import com.twitter.timelines.configapi.decider.RecipientBuilder +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.configapi.params.TimelineConversationsDownrankingSpecificParams +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel._ + +private[visibility] object VisibilityDeciders { + val SafetyLevelToDeciderMap: Map[SafetyLevel, DeciderKey.Value] = Map( + AllSubscribedLists -> DeciderKey.EnableAllSubscribedListsSafetyLevel, + AccessInternalPromotedContent -> DeciderKey.EnableAccessInternalPromotedContentSafetyLevel, + AdsBusinessSettings -> DeciderKey.EnableAdsBusinessSettingsSafetyLevel, + AdsCampaign -> DeciderKey.EnableAdsCampaignSafetyLevel, + AdsManager -> DeciderKey.EnableAdsManagerSafetyLevel, + AdsReportingDashboard -> DeciderKey.EnableAdsReportingDashboardSafetyLevel, + Appeals -> DeciderKey.EnableAppealsSafetyLevel, + ArticleTweetTimeline -> DeciderKey.EnableArticleTweetTimelineSafetyLevel, + BaseQig -> DeciderKey.EnableBaseQig, + BirdwatchNoteAuthor -> DeciderKey.EnableBirdwatchNoteAuthorSafetyLevel, + BirdwatchNoteTweetsTimeline -> DeciderKey.EnableBirdwatchNoteTweetsTimelineSafetyLevel, + BirdwatchNeedsYourHelpNotifications -> DeciderKey.EnableBirdwatchNeedsYourHelpNotificationsSafetyLevel, + BlockMuteUsersTimeline -> DeciderKey.EnableBlockMuteUsersTimelineSafetyLevel, + BrandSafety -> DeciderKey.EnableBrandSafetySafetyLevel, + CardPollVoting -> 
DeciderKey.EnableCardPollVotingSafetyLevel, + CardsService -> DeciderKey.EnableCardsServiceSafetyLevel, + Communities -> DeciderKey.EnableCommunitiesSafetyLevel, + ContentControlToolInstall -> DeciderKey.EnableContentControlToolInstallSafetyLevel, + ConversationFocalPrehydration -> DeciderKey.EnableConversationFocalPrehydrationSafetyLevel, + ConversationFocalTweet -> DeciderKey.EnableConversationFocalTweetSafetyLevel, + ConversationInjectedTweet -> DeciderKey.EnableConversationInjectedTweetSafetyLevel, + ConversationReply -> DeciderKey.EnableConversationReplySafetyLevel, + CuratedTrendsRepresentativeTweet -> DeciderKey.EnableCuratedTrendsRepresentativeTweet, + CurationPolicyViolations -> DeciderKey.EnableCurationPolicyViolations, + DeprecatedSafetyLevel -> DeciderKey.EnableDeprecatedSafetyLevelSafetyLevel, + DevPlatformGetListTweets -> DeciderKey.EnableDevPlatformGetListTweetsSafetyLevel, + DesFollowingAndFollowersUserList -> DeciderKey.EnableDesFollowingAndFollowersUserListSafetyLevel, + DesHomeTimeline -> DeciderKey.EnableDesHomeTimelineSafetyLevel, + DesQuoteTweetTimeline -> DeciderKey.EnableDesQuoteTweetTimelineSafetyLevel, + DesRealtime -> DeciderKey.EnableDesRealtimeSafetyLevel, + DesRealtimeSpamEnrichment -> DeciderKey.EnableDesRealtimeSpamEnrichmentSafetyLevel, + DesRealtimeTweetFilter -> DeciderKey.EnableDesRealtimeTweetFilterSafetyLevel, + DesRetweetingUsers -> DeciderKey.EnableDesRetweetingUsersSafetyLevel, + DesTweetDetail -> DeciderKey.EnableDesTweetDetailSafetyLevel, + DesTweetLikingUsers -> DeciderKey.EnableDesTweetLikingUsersSafetyLevel, + DesUserBookmarks -> DeciderKey.EnableDesUserBookmarksSafetyLevel, + DesUserLikedTweets -> DeciderKey.EnableDesUserLikedTweetsSafetyLevel, + DesUserMentions -> DeciderKey.EnableDesUserMentionsSafetyLevel, + DesUserTweets -> DeciderKey.EnableDesUserTweetsSafetyLevel, + DevPlatformComplianceStream -> DeciderKey.EnableDevPlatformComplianceStreamSafetyLevel, + DirectMessages -> DeciderKey.EnableDirectMessagesSafetyLevel, + DirectMessagesConversationList -> DeciderKey.EnableDirectMessagesConversationListSafetyLevel, + DirectMessagesConversationTimeline -> DeciderKey.EnableDirectMessagesConversationTimelineSafetyLevel, + DirectMessagesInbox -> DeciderKey.EnableDirectMessagesInboxSafetyLevel, + DirectMessagesMutedUsers -> DeciderKey.EnableDirectMessagesMutedUsersSafetyLevel, + DirectMessagesPinned -> DeciderKey.EnableDirectMessagesPinnedSafetyLevel, + DirectMessagesSearch -> DeciderKey.EnableDirectMessagesSearchSafetyLevel, + EditHistoryTimeline -> DeciderKey.EnableEditHistoryTimelineSafetyLevel, + ElevatedQuoteTweetTimeline -> DeciderKey.EnableElevatedQuoteTweetTimelineSafetyLevel, + EmbeddedTweet -> DeciderKey.EnableEmbeddedTweetSafetyLevel, + EmbedsPublicInterestNotice -> DeciderKey.EnableEmbedsPublicInterestNoticeSafetyLevel, + EmbedTweetMarkup -> DeciderKey.EnableEmbedTweetMarkupSafetyLevel, + FilterAll -> DeciderKey.EnableFilterAllSafetyLevel, + FilterAllPlaceholder -> DeciderKey.EnableFilterAllPlaceholderSafetyLevel, + FilterNone -> DeciderKey.EnableFilterNoneSafetyLevel, + FilterDefault -> DeciderKey.EnableFilterDefaultSafetyLevel, + FollowedTopicsTimeline -> DeciderKey.EnableFollowedTopicsTimelineSafetyLevel, + FollowerConnections -> DeciderKey.EnableFollowerConnectionsSafetyLevel, + FollowingAndFollowersUserList -> DeciderKey.EnableFollowingAndFollowersUserListSafetyLevel, + ForDevelopmentOnly -> DeciderKey.EnableForDevelopmentOnlySafetyLevel, + FriendsFollowingList -> DeciderKey.EnableFriendsFollowingListSafetyLevel, + GraphqlDefault 
-> DeciderKey.EnableGraphqlDefaultSafetyLevel,
+    GryphonDecksAndColumns -> DeciderKey.EnableGryphonDecksAndColumnsSafetyLevel,
+    HumanizationNudge -> DeciderKey.EnableHumanizationNudgeSafetyLevel,
+    KitchenSinkDevelopment -> DeciderKey.EnableKitchenSinkDevelopmentSafetyLevel,
+    ListHeader -> DeciderKey.EnableListHeaderSafetyLevel,
+    ListMemberships -> DeciderKey.EnableListMembershipsSafetyLevel,
+    ListOwnerships -> DeciderKey.EnableListOwnershipsSafetyLevel,
+    ListRecommendations -> DeciderKey.EnableListRecommendationsSafetyLevel,
+    ListSearch -> DeciderKey.EnableListSearchSafetyLevel,
+    ListSubscriptions -> DeciderKey.EnableListSubscriptionsSafetyLevel,
+    LiveVideoTimeline -> DeciderKey.EnableLiveVideoTimelineSafetyLevel,
+    LivePipelineEngagementCounts -> DeciderKey.EnableLivePipelineEngagementCountsSafetyLevel,
+    MagicRecs -> DeciderKey.EnableMagicRecsSafetyLevel,
+    MagicRecsAggressive -> DeciderKey.EnableMagicRecsAggressiveSafetyLevel,
+    MagicRecsAggressiveV2 -> DeciderKey.EnableMagicRecsAggressiveV2SafetyLevel,
+    MagicRecsV2 -> DeciderKey.EnableMagicRecsV2SafetyLevel,
+    Minimal -> DeciderKey.EnableMinimalSafetyLevel,
+    ModeratedTweetsTimeline -> DeciderKey.EnableModeratedTweetsTimelineSafetyLevel,
+    Moments -> DeciderKey.EnableMomentsSafetyLevel,
+    NearbyTimeline -> DeciderKey.EnableNearbyTimelineSafetyLevel,
+    NewUserExperience -> DeciderKey.EnableNewUserExperienceSafetyLevel,
+    NotificationsIbis -> DeciderKey.EnableNotificationsIbisSafetyLevel,
+    NotificationsPlatform -> DeciderKey.EnableNotificationsPlatformSafetyLevel,
+    NotificationsPlatformPush -> DeciderKey.EnableNotificationsPlatformPushSafetyLevel,
+    NotificationsQig -> DeciderKey.EnableNotificationsQig,
+    NotificationsRead -> DeciderKey.EnableNotificationsReadSafetyLevel,
+    NotificationsTimelineDeviceFollow -> DeciderKey.EnableNotificationsTimelineDeviceFollowSafetyLevel,
+    NotificationsWrite -> DeciderKey.EnableNotificationsWriteSafetyLevel,
+    NotificationsWriterV2 -> DeciderKey.EnableNotificationsWriterV2SafetyLevel,
+    NotificationsWriterTweetHydrator -> DeciderKey.EnableNotificationsWriterTweetHydratorSafetyLevel,
+    ProfileMixerMedia -> DeciderKey.EnableProfileMixerMediaSafetyLevel,
+    ProfileMixerFavorites -> DeciderKey.EnableProfileMixerFavoritesSafetyLevel,
+    QuickPromoteTweetEligibility -> DeciderKey.EnableQuickPromoteTweetEligibilitySafetyLevel,
+    QuoteTweetTimeline -> DeciderKey.EnableQuoteTweetTimelineSafetyLevel,
+    QuotedTweetRules -> DeciderKey.EnableQuotedTweetRulesSafetyLevel,
+    Recommendations -> DeciderKey.EnableRecommendationsSafetyLevel,
+    RecosVideo -> DeciderKey.EnableRecosVideoSafetyLevel,
+    RecosWritePath -> DeciderKey.EnableRecosWritePathSafetyLevel,
+    RepliesGrouping -> DeciderKey.EnableRepliesGroupingSafetyLevel,
+    ReportCenter -> DeciderKey.EnableReportCenterSafetyLevel,
+    ReturningUserExperience -> DeciderKey.EnableReturningUserExperienceSafetyLevel,
+    ReturningUserExperienceFocalTweet -> DeciderKey.EnableReturningUserExperienceFocalTweetSafetyLevel,
+    Revenue -> DeciderKey.EnableRevenueSafetyLevel,
+    RitoActionedTweetTimeline -> DeciderKey.EnableRitoActionedTweetTimelineSafetyLevel,
+    SafeSearchMinimal -> DeciderKey.EnableSafeSearchMinimalSafetyLevel,
+    SafeSearchStrict -> DeciderKey.EnableSafeSearchStrictSafetyLevel,
+    SearchMixerSrpMinimal -> DeciderKey.EnableSearchMixerSrpMinimalSafetyLevel,
+    SearchMixerSrpStrict -> DeciderKey.EnableSearchMixerSrpStrictSafetyLevel,
+    SearchHydration -> DeciderKey.EnableSearchHydration,
+    SearchLatest -> DeciderKey.EnableSearchLatest,
+    SearchPeopleSrp ->
DeciderKey.EnableSearchPeopleSrp, + SearchPeopleTypeahead -> DeciderKey.EnableSearchPeopleTypeahead, + SearchPhoto -> DeciderKey.EnableSearchPhoto, + SearchTrendTakeoverPromotedTweet -> DeciderKey.EnableSearchTrendTakeoverPromotedTweet, + SearchTop -> DeciderKey.EnableSearchTop, + SearchTopQig -> DeciderKey.EnableSearchTopQig, + SearchVideo -> DeciderKey.EnableSearchVideo, + SearchBlenderUserRules -> DeciderKey.EnableSearchLatestUserRules, + SearchLatestUserRules -> DeciderKey.EnableSearchLatestUserRules, + ShoppingManagerSpyMode -> DeciderKey.EnableShoppingManagerSpyModeSafetyLevel, + SignalsReactions -> DeciderKey.EnableSignalsReactions, + SignalsTweetReactingUsers -> DeciderKey.EnableSignalsTweetReactingUsers, + SocialProof -> DeciderKey.EnableSocialProof, + SoftInterventionPivot -> DeciderKey.EnableSoftInterventionPivot, + SpaceFleetline -> DeciderKey.EnableSpaceFleetlineSafetyLevel, + SpaceHomeTimelineUpranking -> DeciderKey.EnableSpaceHomeTimelineUprankingSafetyLevel, + SpaceJoinScreen -> DeciderKey.EnableSpaceJoinScreenSafetyLevel, + SpaceNotifications -> DeciderKey.EnableSpaceNotificationsSafetyLevel, + Spaces -> DeciderKey.EnableSpacesSafetyLevel, + SpacesParticipants -> DeciderKey.EnableSpacesParticipantsSafetyLevel, + SpacesSellerApplicationStatus -> DeciderKey.EnableSpacesSellerApplicationStatus, + SpacesSharing -> DeciderKey.EnableSpacesSharingSafetyLevel, + SpaceTweetAvatarHomeTimeline -> DeciderKey.EnableSpaceTweetAvatarHomeTimelineSafetyLevel, + StickersTimeline -> DeciderKey.EnableStickersTimelineSafetyLevel, + StratoExtLimitedEngagements -> DeciderKey.EnableStratoExtLimitedEngagementsSafetyLevel, + StreamServices -> DeciderKey.EnableStreamServicesSafetyLevel, + SuperFollowerConnections -> DeciderKey.EnableSuperFollowerConnectionsSafetyLevel, + SuperLike -> DeciderKey.EnableSuperLikeSafetyLevel, + Test -> DeciderKey.EnableTestSafetyLevel, + TimelineContentControls -> DeciderKey.EnableTimelineContentControlsSafetyLevel, + TimelineConversations -> DeciderKey.EnableTimelineConversationsSafetyLevel, + TimelineConversationsDownranking -> DeciderKey.EnableTimelineConversationsDownrankingSafetyLevel, + TimelineConversationsDownrankingMinimal -> DeciderKey.EnableTimelineConversationsDownrankingMinimalSafetyLevel, + TimelineFollowingActivity -> DeciderKey.EnableTimelineFollowingActivitySafetyLevel, + TimelineHome -> DeciderKey.EnableTimelineHomeSafetyLevel, + TimelineHomeCommunities -> DeciderKey.EnableTimelineHomeCommunitiesSafetyLevel, + TimelineHomeHydration -> DeciderKey.EnableTimelineHomeHydrationSafetyLevel, + TimelineHomePromotedHydration -> DeciderKey.EnableTimelineHomePromotedHydrationSafetyLevel, + TimelineHomeRecommendations -> DeciderKey.EnableTimelineHomeRecommendationsSafetyLevel, + TimelineHomeTopicFollowRecommendations -> DeciderKey.EnableTimelineHomeTopicFollowRecommendationsSafetyLevel, + TimelineScorer -> DeciderKey.EnableTimelineScorerSafetyLevel, + TopicsLandingPageTopicRecommendations -> DeciderKey.EnableTopicsLandingPageTopicRecommendationsSafetyLevel, + ExploreRecommendations -> DeciderKey.EnableExploreRecommendationsSafetyLevel, + TimelineInjection -> DeciderKey.EnableTimelineInjectionSafetyLevel, + TimelineMentions -> DeciderKey.EnableTimelineMentionsSafetyLevel, + TimelineModeratedTweetsHydration -> DeciderKey.EnableTimelineModeratedTweetsHydrationSafetyLevel, + TimelineHomeLatest -> DeciderKey.EnableTimelineHomeLatestSafetyLevel, + TimelineLikedBy -> DeciderKey.EnableTimelineLikedBySafetyLevel, + TimelineRetweetedBy -> 
DeciderKey.EnableTimelineRetweetedBySafetyLevel, + TimelineSuperLikedBy -> DeciderKey.EnableTimelineSuperLikedBySafetyLevel, + TimelineBookmark -> DeciderKey.EnableTimelineBookmarkSafetyLevel, + TimelineMedia -> DeciderKey.EnableTimelineMediaSafetyLevel, + TimelineReactiveBlending -> DeciderKey.EnableTimelineReactiveBlendingSafetyLevel, + TimelineFavorites -> DeciderKey.EnableTimelineFavoritesSafetyLevel, + TimelineFavoritesSelfView -> DeciderKey.EnableSelfViewTimelineFavoritesSafetyLevel, + TimelineLists -> DeciderKey.EnableTimelineListsSafetyLevel, + TimelineProfile -> DeciderKey.EnableTimelineProfileSafetyLevel, + TimelineProfileAll -> DeciderKey.EnableTimelineProfileAllSafetyLevel, + TimelineProfileSpaces -> DeciderKey.EnableTimelineProfileSpacesSafetyLevel, + TimelineProfileSuperFollows -> DeciderKey.EnableTimelineProfileSuperFollowsSafetyLevel, + TimelineFocalTweet -> DeciderKey.EnableTweetTimelineFocalTweetSafetyLevel, + TweetDetailWithInjectionsHydration -> DeciderKey.EnableTweetDetailWithInjectionsHydrationSafetyLevel, + Tombstoning -> DeciderKey.EnableTombstoningSafetyLevel, + TopicRecommendations -> DeciderKey.EnableTopicRecommendationsSafetyLevel, + TrendsRepresentativeTweet -> DeciderKey.EnableTrendsRepresentativeTweetSafetyLevel, + TrustedFriendsUserList -> DeciderKey.EnableTrustedFriendsUserListSafetyLevel, + TweetDetail -> DeciderKey.EnableTweetDetailSafetyLevel, + TweetDetailNonToo -> DeciderKey.EnableTweetDetailNonTooSafetyLevel, + TweetEngagers -> DeciderKey.EnableTweetEngagersSafetyLevel, + TweetReplyNudge -> DeciderKey.EnableTweetReplyNudgeSafetyLevel, + TweetScopedTimeline -> DeciderKey.EnableTweetScopedTimelineSafetyLevel, + TweetWritesApi -> DeciderKey.EnableTweetWritesApiSafetyLevel, + TwitterArticleCompose -> DeciderKey.EnableTwitterArticleComposeSafetyLevel, + TwitterArticleProfileTab -> DeciderKey.EnableTwitterArticleProfileTabSafetyLevel, + TwitterArticleRead -> DeciderKey.EnableTwitterArticleReadSafetyLevel, + UserProfileHeader -> DeciderKey.EnableUserProfileHeaderSafetyLevel, + UserMilestoneRecommendation -> DeciderKey.EnableUserMilestoneRecommendationSafetyLevel, + UserScopedTimeline -> DeciderKey.EnableUserScopedTimelineSafetyLevel, + UserSearchSrp -> DeciderKey.EnableUserSearchSrpSafetyLevel, + UserSearchTypeahead -> DeciderKey.EnableUserSearchTypeaheadSafetyLevel, + UserSelfViewOnly -> DeciderKey.EnableUserSelfViewOnlySafetyLevel, + UserSettings -> DeciderKey.EnableUserSettingsSafetyLevel, + VideoAds -> DeciderKey.EnableVideoAdsSafetyLevel, + WritePathLimitedActionsEnforcement -> DeciderKey.EnableWritePathLimitedActionsEnforcementSafetyLevel, + ZipbirdConsumerArchives -> DeciderKey.EnableZipbirdConsumerArchivesSafetyLevel, + TweetAward -> DeciderKey.EnableTweetAwardSafetyLevel, + ) + + val BoolToDeciderMap: Map[Param[Boolean], DeciderKey.Value] = Map( + RuleParams.TweetConversationControlEnabledParam -> + DeciderKey.EnableTweetConversationControlRules, + RuleParams.CommunityTweetsEnabledParam -> + DeciderKey.EnableCommunityTweetsControlRules, + RuleParams.DropCommunityTweetWithUndefinedCommunityRuleEnabledParam -> + DeciderKey.EnableDropCommunityTweetWithUndefinedCommunityRule, + TimelineConversationsDownrankingSpecificParams.EnablePSpammyTweetDownrankConvosLowQualityParam -> + DeciderKey.EnablePSpammyTweetDownrankConvosLowQuality, + RuleParams.EnableHighPSpammyTweetScoreSearchTweetLabelDropRuleParam -> + DeciderKey.EnableHighPSpammyTweetScoreSearchTweetLabelDropRule, + 
TimelineConversationsDownrankingSpecificParams.EnableRitoActionedTweetDownrankConvosLowQualityParam -> + DeciderKey.EnableRitoActionedTweetDownrankConvosLowQuality, + RuleParams.EnableSmyteSpamTweetRuleParam -> + DeciderKey.EnableSmyteSpamTweetRule, + RuleParams.EnableHighSpammyTweetContentScoreSearchLatestTweetLabelDropRuleParam -> + DeciderKey.EnableHighSpammyTweetContentScoreSearchLatestTweetLabelDropRule, + RuleParams.EnableHighSpammyTweetContentScoreSearchTopTweetLabelDropRuleParam -> + DeciderKey.EnableHighSpammyTweetContentScoreSearchTopTweetLabelDropRule, + RuleParams.EnableHighSpammyTweetContentScoreTrendsTopTweetLabelDropRuleParam -> + DeciderKey.EnableHighSpammyTweetContentScoreTrendsTopTweetLabelDropRule, + RuleParams.EnableHighSpammyTweetContentScoreTrendsLatestTweetLabelDropRuleParam -> + DeciderKey.EnableHighSpammyTweetContentScoreTrendsLatestTweetLabelDropRule, + TimelineConversationsDownrankingSpecificParams.EnableHighSpammyTweetContentScoreConvoDownrankAbusiveQualityRuleParam -> + DeciderKey.EnableHighSpammyTweetContentScoreConvoDownrankAbusiveQualityRule, + TimelineConversationsDownrankingSpecificParams.EnableHighCryptospamScoreConvoDownrankAbusiveQualityRuleParam -> + DeciderKey.EnableHighCryptospamScoreConvoDownrankAbusiveQualityRule, + RuleParams.EnableGoreAndViolenceTopicHighRecallTweetLabelRule -> + DeciderKey.EnableGoreAndViolenceTopicHighRecallTweetLabelRule, + RuleParams.EnableLimitRepliesFollowersConversationRule -> + DeciderKey.EnableLimitRepliesFollowersConversationRule, + RuleParams.EnableSearchBasicBlockMuteRulesParam -> DeciderKey.EnableSearchBasicBlockMuteRules, + RuleParams.EnableBlinkBadDownrankingRuleParam -> + DeciderKey.EnableBlinkBadDownrankingRule, + RuleParams.EnableBlinkWorstDownrankingRuleParam -> + DeciderKey.EnableBlinkWorstDownrankingRule, + RuleParams.EnableCopypastaSpamDownrankConvosAbusiveQualityRule -> + DeciderKey.EnableCopypastaSpamDownrankConvosAbusiveQualityRule, + RuleParams.EnableCopypastaSpamSearchDropRule -> + DeciderKey.EnableCopypastaSpamSearchDropRule, + RuleParams.EnableSpammyUserModelTweetDropRuleParam -> + DeciderKey.EnableSpammyUserModelHighPrecisionDropTweetRule, + RuleParams.EnableAvoidNsfwRulesParam -> + DeciderKey.EnableAvoidNsfwRules, + RuleParams.EnableReportedTweetInterstitialRule -> + DeciderKey.EnableReportedTweetInterstitialRule, + RuleParams.EnableReportedTweetInterstitialSearchRule -> + DeciderKey.EnableReportedTweetInterstitialSearchRule, + RuleParams.EnableDropExclusiveTweetContentRule -> + DeciderKey.EnableDropExclusiveTweetContentRule, + RuleParams.EnableDropExclusiveTweetContentRuleFailClosed -> + DeciderKey.EnableDropExclusiveTweetContentRuleFailClosed, + RuleParams.EnableTombstoneExclusiveQtProfileTimelineParam -> + DeciderKey.EnableTombstoneExclusiveQtProfileTimelineParam, + RuleParams.EnableDropAllExclusiveTweetsRuleParam -> DeciderKey.EnableDropAllExclusiveTweetsRule, + RuleParams.EnableDropAllExclusiveTweetsRuleFailClosedParam -> DeciderKey.EnableDropAllExclusiveTweetsRuleFailClosed, + RuleParams.EnableDownrankSpamReplySectioningRuleParam -> + DeciderKey.EnableDownrankSpamReplySectioningRule, + RuleParams.EnableNsfwTextSectioningRuleParam -> + DeciderKey.EnableNsfwTextSectioningRule, + RuleParams.EnableSearchIpiSafeSearchWithoutUserInQueryDropRule -> DeciderKey.EnableSearchIpiSafeSearchWithoutUserInQueryDropRule, + RuleParams.EnableTimelineHomePromotedTweetHealthEnforcementRules -> DeciderKey.EnableTimelineHomePromotedTweetHealthEnforcementRules, + 
RuleParams.EnableMutedKeywordFilteringSpaceTitleNotificationsRuleParam -> DeciderKey.EnableMutedKeywordFilteringSpaceTitleNotificationsRule, + RuleParams.EnableDropTweetsWithGeoRestrictedMediaRuleParam -> DeciderKey.EnableDropTweetsWithGeoRestrictedMediaRule, + RuleParams.EnableDropAllTrustedFriendsTweetsRuleParam -> DeciderKey.EnableDropAllTrustedFriendsTweetsRule, + RuleParams.EnableDropTrustedFriendsTweetContentRuleParam -> DeciderKey.EnableDropTrustedFriendsTweetContentRule, + RuleParams.EnableDropAllCollabInvitationTweetsRuleParam -> DeciderKey.EnableDropCollabInvitationTweetsRule, + RuleParams.EnableNsfwTextTopicsDropRuleParam -> DeciderKey.EnableNsfwTextTopicsDropRule, + RuleParams.EnableLikelyIvsUserLabelDropRule -> DeciderKey.EnableLikelyIvsUserLabelDropRule, + RuleParams.EnableCardUriRootDomainCardDenylistRule -> DeciderKey.EnableCardUriRootDomainDenylistRule, + RuleParams.EnableCommunityNonMemberPollCardRule -> DeciderKey.EnableCommunityNonMemberPollCardRule, + RuleParams.EnableCommunityNonMemberPollCardRuleFailClosed -> DeciderKey.EnableCommunityNonMemberPollCardRuleFailClosed, + RuleParams.EnableExperimentalNudgeEnabledParam -> DeciderKey.EnableExperimentalNudgeLabelRule, + RuleParams.NsfwHighPrecisionUserLabelAvoidTweetRuleEnabledParam -> DeciderKey.NsfwHighPrecisionUserLabelAvoidTweetRuleEnabledParam, + RuleParams.EnableNewAdAvoidanceRulesParam -> DeciderKey.EnableNewAdAvoidanceRules, + RuleParams.EnableNsfaHighRecallAdAvoidanceParam -> DeciderKey.EnableNsfaHighRecallAdAvoidanceParam, + RuleParams.EnableNsfaKeywordsHighPrecisionAdAvoidanceParam -> DeciderKey.EnableNsfaKeywordsHighPrecisionAdAvoidanceParam, + RuleParams.EnableStaleTweetDropRuleParam -> DeciderKey.EnableStaleTweetDropRuleParam, + RuleParams.EnableStaleTweetDropRuleFailClosedParam -> DeciderKey.EnableStaleTweetDropRuleFailClosedParam, + RuleParams.EnableDeleteStateTweetRulesParam -> DeciderKey.EnableDeleteStateTweetRules, + RuleParams.EnableSpacesSharingNsfwDropRulesParam -> DeciderKey.EnableSpacesSharingNsfwDropRulesParam, + RuleParams.EnableViewerIsSoftUserDropRuleParam -> DeciderKey.EnableViewerIsSoftUserDropRuleParam, + RuleParams.EnablePdnaQuotedTweetTombstoneRuleParam -> DeciderKey.EnablePdnaQuotedTweetTombstoneRule, + RuleParams.EnableSpamQuotedTweetTombstoneRuleParam -> DeciderKey.EnableSpamQuotedTweetTombstoneRule, + RuleParams.EnableNsfwHpQuotedTweetDropRuleParam -> DeciderKey.EnableNsfwHpQuotedTweetDropRule, + RuleParams.EnableNsfwHpQuotedTweetTombstoneRuleParam -> DeciderKey.EnableNsfwHpQuotedTweetTombstoneRule, + RuleParams.EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam -> DeciderKey.EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRule, + RuleParams.EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam -> DeciderKey.EnableInnerQuotedTweetViewerMutesAuthorInterstitialRule, + RuleParams.EnableToxicReplyFilteringConversationRulesParam -> DeciderKey.VisibilityLibraryEnableToxicReplyFilterConversation, + RuleParams.EnableToxicReplyFilteringNotificationsRulesParam -> DeciderKey.VisibilityLibraryEnableToxicReplyFilterNotifications, + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam -> DeciderKey.EnableLegacySensitiveMediaRulesHomeTimeline, + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam -> DeciderKey.EnableNewSensitiveMediaSettingsInterstitialRulesHomeTimeline, + RuleParams.EnableLegacySensitiveMediaConversationRulesParam -> DeciderKey.EnableLegacySensitiveMediaRulesConversation, + 
RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam -> DeciderKey.EnableNewSensitiveMediaSettingsInterstitialRulesConversation,
+    RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam -> DeciderKey.EnableLegacySensitiveMediaRulesProfileTimeline,
+    RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam -> DeciderKey.EnableNewSensitiveMediaSettingsInterstitialRulesProfileTimeline,
+    RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam -> DeciderKey.EnableLegacySensitiveMediaRulesTweetDetail,
+    RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam -> DeciderKey.EnableNewSensitiveMediaSettingsInterstitialRulesTweetDetail,
+    RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam -> DeciderKey.EnableLegacySensitiveMediaRulesDirectMessages,
+    RuleParams.EnableAbusiveBehaviorDropRuleParam -> DeciderKey.EnableAbusiveBehaviorDropRule,
+    RuleParams.EnableAbusiveBehaviorInterstitialRuleParam -> DeciderKey.EnableAbusiveBehaviorInterstitialRule,
+    RuleParams.EnableAbusiveBehaviorLimitedEngagementsRuleParam -> DeciderKey.EnableAbusiveBehaviorLimitedEngagementsRule,
+    RuleParams.EnableNotGraduatedDownrankConvosAbusiveQualityRuleParam -> DeciderKey.EnableNotGraduatedDownrankConvosAbusiveQualityRule,
+    RuleParams.EnableNotGraduatedSearchDropRuleParam -> DeciderKey.EnableNotGraduatedSearchDropRule,
+    RuleParams.EnableNotGraduatedDropRuleParam -> DeciderKey.EnableNotGraduatedDropRule,
+    RuleParams.EnableFosnrRuleParam -> DeciderKey.EnableFosnrRules,
+    RuleParams.EnableAuthorBlocksViewerDropRuleParam -> DeciderKey.EnableAuthorBlocksViewerDropRule
+  )
+
+  def config(
+    deciderGateBuilder: DeciderGateBuilder,
+    logger: Logger,
+    statsReceiver: StatsReceiver,
+    safetyLevel: SafetyLevel
+  ): Config = {
+
+    // Resolves the decider recipient for a request: prefer the user id, fall
+    // back to the guest id, then to request-level bucketing; counters record
+    // which branch was taken.
+    object UserOrGuestOrRequest extends RecipientBuilder {
+      private val scopedStats = statsReceiver.scope("decider_recipient")
+      private val userIdDefinedCounter = scopedStats.counter("user_id_defined")
+      private val userIdNotDefinedCounter = scopedStats.counter("user_id_undefined")
+      private val guestIdDefinedCounter = scopedStats.counter("guest_id_defined")
+      private val guestIdNotDefinedCounter = scopedStats.counter("guest_id_undefined")
+      private val noIdCounter = scopedStats.counter("no_id_defined")
+
+      def apply(requestContext: BaseRequestContext): Option[Recipient] = requestContext match {
+        case c: WithUserId if c.userId.isDefined =>
+          userIdDefinedCounter.incr()
+          c.userId.map(SimpleRecipient)
+        case c: WithGuestId if c.guestId.isDefined =>
+          guestIdDefinedCounter.incr()
+          c.guestId.map(GuestRecipient)
+        case c: WithGuestId =>
+          guestIdNotDefinedCounter.incr()
+          RecipientBuilder.Request(c)
+        case _: WithUserId =>
+          userIdNotDefinedCounter.incr()
+          None
+        case _ =>
+          logger.warning("Request Context with no user or guest id trait found: " + requestContext)
+          noIdCounter.incr()
+          None
+      }
+    }
+
+    // Each rule param is driven by its decider: it overrides to true when the
+    // decider passes for the resolved recipient, and to false otherwise.
+    val boolOverrides = BoolToDeciderMap.map {
+      case (param, deciderKey) =>
+        param.optionalOverrideValue(
+          DeciderSwitchOverrideValue(
+            feature = deciderGateBuilder.keyToFeature(deciderKey),
+            enabledValue = true,
+            disabledValueOption = Some(false),
+            recipientBuilder = UserOrGuestOrRequest
+          )
+        )
+    }.toSeq
+
+    // The safety level's enabledParam is only overridden to true when its
+    // decider passes; otherwise it keeps its default.
+    val safetyLevelOverride = safetyLevel.enabledParam.optionalOverrideValue(
+      DeciderSwitchOverrideValue(
+        feature = deciderGateBuilder.keyToFeature(SafetyLevelToDeciderMap(safetyLevel)),
+        enabledValue = true,
+        recipientBuilder = UserOrGuestOrRequest
+      )
+    )
+
+    BaseConfigBuilder(boolOverrides :+
safetyLevelOverride).build("VisibilityDeciders") + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityExperimentsConfig.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityExperimentsConfig.scala new file mode 100644 index 000000000..faf25ee2a --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityExperimentsConfig.scala @@ -0,0 +1,33 @@ +package com.twitter.visibility.configapi.configs + +import com.twitter.timelines.configapi.Config +import com.twitter.visibility.configapi.params.RuleParams._ +import com.twitter.visibility.configapi.params.VisibilityExperiments._ +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel._ + +private[visibility] object VisibilityExperimentsConfig { + import ExperimentsHelper._ + + val TestExperimentConfig: Config = mkABExperimentConfig(TestExperiment, TestHoldbackParam) + + val NotGraduatedUserLabelRuleHoldbackExperimentConfig: Config = + mkABExperimentConfig( + NotGraduatedUserLabelRuleExperiment, + NotGraduatedUserLabelRuleHoldbackExperimentParam + ) + + def config(safetyLevel: SafetyLevel): Seq[Config] = { + + val experimentConfigs = safetyLevel match { + + case Test => + Seq(TestExperimentConfig) + + case _ => Seq(NotGraduatedUserLabelRuleHoldbackExperimentConfig) + } + + experimentConfigs + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityFeatureSwitches.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityFeatureSwitches.scala new file mode 100644 index 000000000..2fa33cd57 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/VisibilityFeatureSwitches.scala @@ -0,0 +1,74 @@ +package com.twitter.visibility.configapi.configs + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logging.Logger +import com.twitter.timelines.configapi._ +import com.twitter.util.Time +import com.twitter.visibility.configapi.params.FSEnumRuleParam +import com.twitter.visibility.configapi.params.FSRuleParams._ + +private[visibility] object VisibilityFeatureSwitches { + + val booleanFsOverrides: Seq[OptionalOverride[Boolean]] = + FeatureSwitchOverrideUtil.getBooleanFSOverrides( + AgeGatingAdultContentExperimentRuleEnabledParam, + CommunityTweetCommunityUnavailableLimitedActionsRulesEnabledParam, + CommunityTweetDropProtectedRuleEnabledParam, + CommunityTweetDropRuleEnabledParam, + CommunityTweetLimitedActionsRulesEnabledParam, + CommunityTweetMemberRemovedLimitedActionsRulesEnabledParam, + CommunityTweetNonMemberLimitedActionsRuleEnabledParam, + NsfwAgeBasedDropRulesHoldbackParam, + SkipTweetDetailLimitedEngagementRuleEnabledParam, + StaleTweetLimitedActionsRulesEnabledParam, + TrustedFriendsTweetLimitedEngagementsRuleEnabledParam, + FosnrFallbackDropRulesEnabledParam, + FosnrRulesEnabledParam + ) + + val doubleFsOverrides: Seq[OptionalOverride[Double]] = + FeatureSwitchOverrideUtil.getBoundedDoubleFSOverrides( + HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThresholdParam, + HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThresholdParam, + HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThresholdParam, + HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThresholdParam, + HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThresholdParam, + HighToxicityModelScoreSpaceThresholdParam, + AdAvoidanceHighToxicityModelScoreThresholdParam, + 
AdAvoidanceReportedTweetModelScoreThresholdParam, + ) + + val timeFsOverrides: Seq[OptionalOverride[Time]] = + FeatureSwitchOverrideUtil.getTimeFromStringFSOverrides() + + val stringSeqFeatureSwitchOverrides: Seq[OptionalOverride[Seq[String]]] = + FeatureSwitchOverrideUtil.getStringSeqFSOverrides( + CountrySpecificNsfwContentGatingCountriesParam, + AgeGatingAdultContentExperimentCountriesParam, + CardUriRootDomainDenyListParam + ) + + val enumFsParams: Seq[FSEnumRuleParam[_ <: Enumeration]] = Seq() + + val mkOptionalEnumFsOverrides: (StatsReceiver, Logger) => Seq[OptionalOverride[_]] = { + (statsReceiver: StatsReceiver, logger: Logger) => + FeatureSwitchOverrideUtil.getEnumFSOverrides( + statsReceiver, + logger, + enumFsParams: _* + ) + } + + def overrides(statsReceiver: StatsReceiver, logger: Logger): Seq[OptionalOverride[_]] = { + val enumOverrides = mkOptionalEnumFsOverrides(statsReceiver, logger) + booleanFsOverrides ++ + doubleFsOverrides ++ + timeFsOverrides ++ + stringSeqFeatureSwitchOverrides ++ + enumOverrides + } + + def config(statsReceiver: StatsReceiver, logger: Logger): BaseConfig = + BaseConfigBuilder(overrides(statsReceiver.scope("features_switches"), logger)) + .build("VisibilityFeatureSwitches") +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/BUILD new file mode 100644 index 000000000..b59bedb48 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/BUILD @@ -0,0 +1,9 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "decider", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/VisibilityLibraryDeciderOverrides.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/VisibilityLibraryDeciderOverrides.scala new file mode 100644 index 000000000..842309be7 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/configs/overrides/VisibilityLibraryDeciderOverrides.scala @@ -0,0 +1,24 @@ +package com.twitter.visibility.configapi.configs.overrides + +import com.twitter.decider.LocalOverrides + +object VisibilityLibraryDeciderOverrides + extends LocalOverrides.Namespace("visibility-library", "") { + + val EnableLocalizedTombstoneOnVisibilityResults = feature( + "visibility_library_enable_localized_tombstones_on_visibility_results") + + val EnableLocalizedInterstitialGenerator: LocalOverrides.Override = + feature("visibility_library_enable_localized_interstitial_generator") + + val EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRule: LocalOverrides.Override = + feature("visibility_library_enable_inner_quoted_tweet_viewer_blocks_author_interstitial_rule") + val EnableInnerQuotedTweetViewerMutesAuthorInterstitialRule: LocalOverrides.Override = + feature("visibility_library_enable_inner_quoted_tweet_viewer_mutes_author_interstitial_rule") + + val EnableBackendLimitedActions: LocalOverrides.Override = + feature("visibility_library_enable_backend_limited_actions") + + val disableLegacyInterstitialFilteredReason: LocalOverrides.Override = feature( + "visibility_library_disable_legacy_interstitial_filtered_reason") +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/BUILD new file 
mode 100644
index 000000000..b2855c4c6
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/BUILD
@@ -0,0 +1,15 @@
+scala_library(
+    sources = ["*.scala"],
+    compiler_option_sets = ["fatal_warnings"],
+    platform = "java8",
+    strict_deps = True,
+    tags = ["bazel-compatible"],
+    dependencies = [
+        "configapi/configapi-core",
+        "finagle/finagle-stats",
+        "visibility/common/src/main/scala/com/twitter/visibility/common:model_thresholds",
+    ],
+    exports = [
+        "visibility/common/src/main/scala/com/twitter/visibility/common:model_thresholds",
+    ],
+)
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/FSRuleParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/FSRuleParams.scala
new file mode 100644
index 000000000..270db45aa
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/FSRuleParams.scala
@@ -0,0 +1,213 @@
+package com.twitter.visibility.configapi.params
+
+import com.twitter.timelines.configapi.Bounded
+import com.twitter.timelines.configapi.FSBoundedParam
+import com.twitter.timelines.configapi.FSName
+import com.twitter.timelines.configapi.FeatureName
+import com.twitter.timelines.configapi.HasTimeConversion
+import com.twitter.timelines.configapi.TimeConversion
+import com.twitter.util.Time
+import com.twitter.visibility.common.ModelScoreThresholds
+
+private[visibility] object FeatureSwitchKey extends Enumeration {
+  type FeatureSwitchKey = String
+
+  final val HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThreshold =
+    "high_spammy_tweet_content_score_search_top_prod_tweet_label_drop_rule_threshold"
+  final val HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThreshold =
+    "high_spammy_tweet_content_score_search_latest_prod_tweet_label_drop_rule_threshold"
+  final val HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThreshold =
+    "high_spammy_tweet_content_score_trend_top_tweet_label_drop_rule_threshold"
+  final val HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThreshold =
+    "high_spammy_tweet_content_score_trend_latest_tweet_label_drop_rule_threshold"
+  final val HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThreshold =
+    "high_spammy_tweet_content_score_convos_downranking_abusive_quality_threshold"
+
+  final val NsfwAgeBasedDropRulesHoldbackParam =
+    "nsfw_age_based_drop_rules_holdback"
+
+  final val CommunityTweetDropRuleEnabled =
+    "community_tweet_drop_rule_enabled"
+  final val CommunityTweetDropProtectedRuleEnabled =
+    "community_tweet_drop_protected_rule_enabled"
+  final val CommunityTweetLimitedActionsRulesEnabled =
+    "community_tweet_limited_actions_rules_enabled"
+  final val CommunityTweetMemberRemovedLimitedActionsRulesEnabled =
+    "community_tweet_member_removed_limited_actions_rules_enabled"
+  final val CommunityTweetCommunityUnavailableLimitedActionsRulesEnabled =
+    "community_tweet_community_unavailable_limited_actions_rules_enabled"
+  final val CommunityTweetNonMemberLimitedActionsRuleEnabled =
+    "community_tweet_non_member_limited_actions_rule_enabled"
+
+  final val TrustedFriendsTweetLimitedEngagementsRuleEnabled =
+    "trusted_friends_tweet_limited_engagements_rule_enabled"
+
+  final val CountrySpecificNsfwContentGatingCountries =
+    "country_specific_nsfw_content_gating_countries"
+
+  final val AgeGatingAdultContentExperimentCountries =
+    "age_gating_adult_content_experiment_countries"
+  final val AgeGatingAdultContentExperimentEnabled =
+    "age_gating_adult_content_experiment_enabled"
+
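+  // The string values in this object are the feature-switch names looked up
+  // at runtime; each is bound to a default value (and, for bounded params, a
+  // [min, max] range) by a matching *Param object in FSRuleParams below.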
+  final val HighToxicityModelScoreSpaceThreshold =
+    "high_toxicity_model_score_space_threshold"
+
+  final val CardUriRootDomainDenyList = "card_uri_root_domain_deny_list"
+
+  final val SkipTweetDetailLimitedEngagementsRuleEnabled =
+    "skip_tweet_detail_limited_engagements_rule_enabled"
+
+  final val AdAvoidanceHighToxicityModelScoreThreshold =
+    "ad_avoidance_model_thresholds_high_toxicity_model"
+  final val AdAvoidanceReportedTweetModelScoreThreshold =
+    "ad_avoidance_model_thresholds_reported_tweet_model"
+
+  final val StaleTweetLimitedActionsRulesEnabled =
+    "stale_tweet_limited_actions_rules_enabled"
+
+  final val FosnrFallbackDropRulesEnabled =
+    "freedom_of_speech_not_reach_fallback_drop_rules_enabled"
+  final val FosnrRulesEnabled =
+    "freedom_of_speech_not_reach_rules_enabled"
+}
+
+abstract class FSRuleParam[T](override val name: FeatureName, override val default: T)
+    extends RuleParam(default)
+    with FSName
+
+abstract class FSBoundedRuleParam[T](
+  override val name: FeatureName,
+  override val default: T,
+  override val min: T,
+  override val max: T
+)(
+  implicit override val ordering: Ordering[T])
+    extends RuleParam(default)
+    with Bounded[T]
+    with FSName
+
+abstract class FSTimeRuleParam[T](
+  override val name: FeatureName,
+  override val default: Time,
+  override val timeConversion: TimeConversion[T])
+    extends RuleParam(default)
+    with HasTimeConversion[T]
+    with FSName
+
+abstract class FSEnumRuleParam[T <: Enumeration](
+  override val name: FeatureName,
+  override val default: T#Value,
+  override val enum: T)
+    extends EnumRuleParam(default, enum)
+    with FSName
+
+private[visibility] object FSRuleParams {
+  object HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThresholdParam
+      extends FSBoundedParam(
+        FeatureSwitchKey.HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThreshold,
+        default = ModelScoreThresholds.HighSpammyTweetContentScoreDefaultThreshold,
+        min = 0,
+        max = 1)
+  object HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThresholdParam
+      extends FSBoundedParam(
+        FeatureSwitchKey.HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThreshold,
+        default = ModelScoreThresholds.HighSpammyTweetContentScoreDefaultThreshold,
+        min = 0,
+        max = 1)
+  object HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThresholdParam
+      extends FSBoundedParam(
+        FeatureSwitchKey.HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThreshold,
+        default = ModelScoreThresholds.HighSpammyTweetContentScoreDefaultThreshold,
+        min = 0,
+        max = 1)
+  object HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThresholdParam
+      extends FSBoundedParam(
+        FeatureSwitchKey.HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThreshold,
+        default = ModelScoreThresholds.HighSpammyTweetContentScoreDefaultThreshold,
+        min = 0,
+        max = 1)
+  object HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThresholdParam
+      extends FSBoundedParam(
+        FeatureSwitchKey.HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThreshold,
+        default = ModelScoreThresholds.HighSpammyTweetContentScoreDefaultThreshold,
+        min = 0,
+        max = 1)
+
+  object CommunityTweetDropRuleEnabledParam
+      extends FSRuleParam(FeatureSwitchKey.CommunityTweetDropRuleEnabled, true)
+
+  object CommunityTweetDropProtectedRuleEnabledParam
+      extends FSRuleParam(FeatureSwitchKey.CommunityTweetDropProtectedRuleEnabled, true)
+
+  object CommunityTweetLimitedActionsRulesEnabledParam
+      extends FSRuleParam(FeatureSwitchKey.CommunityTweetLimitedActionsRulesEnabled, false)
+
+  object
CommunityTweetMemberRemovedLimitedActionsRulesEnabledParam + extends FSRuleParam( + FeatureSwitchKey.CommunityTweetMemberRemovedLimitedActionsRulesEnabled, + false) + + object CommunityTweetCommunityUnavailableLimitedActionsRulesEnabledParam + extends FSRuleParam( + FeatureSwitchKey.CommunityTweetCommunityUnavailableLimitedActionsRulesEnabled, + false) + + object CommunityTweetNonMemberLimitedActionsRuleEnabledParam + extends FSRuleParam(FeatureSwitchKey.CommunityTweetNonMemberLimitedActionsRuleEnabled, false) + + object TrustedFriendsTweetLimitedEngagementsRuleEnabledParam + extends FSRuleParam(FeatureSwitchKey.TrustedFriendsTweetLimitedEngagementsRuleEnabled, false) + + object SkipTweetDetailLimitedEngagementRuleEnabledParam + extends FSRuleParam(FeatureSwitchKey.SkipTweetDetailLimitedEngagementsRuleEnabled, false) + + + object NsfwAgeBasedDropRulesHoldbackParam + extends FSRuleParam(FeatureSwitchKey.NsfwAgeBasedDropRulesHoldbackParam, true) + + object CountrySpecificNsfwContentGatingCountriesParam + extends FSRuleParam[Seq[String]]( + FeatureSwitchKey.CountrySpecificNsfwContentGatingCountries, + default = Seq("au")) + + object AgeGatingAdultContentExperimentCountriesParam + extends FSRuleParam[Seq[String]]( + FeatureSwitchKey.AgeGatingAdultContentExperimentCountries, + default = Seq.empty) + object AgeGatingAdultContentExperimentRuleEnabledParam + extends FSRuleParam(FeatureSwitchKey.AgeGatingAdultContentExperimentEnabled, default = false) + + object HighToxicityModelScoreSpaceThresholdParam + extends FSBoundedParam( + FeatureSwitchKey.HighToxicityModelScoreSpaceThreshold, + default = ModelScoreThresholds.HighToxicityModelScoreSpaceDefaultThreshold, + min = 0, + max = 1) + + object CardUriRootDomainDenyListParam + extends FSRuleParam[Seq[String]]( + FeatureSwitchKey.CardUriRootDomainDenyList, + default = Seq.empty) + + object AdAvoidanceHighToxicityModelScoreThresholdParam + extends FSBoundedParam( + FeatureSwitchKey.AdAvoidanceHighToxicityModelScoreThreshold, + default = ModelScoreThresholds.AdAvoidanceHighToxicityModelScoreDefaultThreshold, + min = 0, + max = 1) + + object AdAvoidanceReportedTweetModelScoreThresholdParam + extends FSBoundedParam( + FeatureSwitchKey.AdAvoidanceReportedTweetModelScoreThreshold, + default = ModelScoreThresholds.AdAvoidanceReportedTweetModelScoreDefaultThreshold, + min = 0, + max = 1) + + object StaleTweetLimitedActionsRulesEnabledParam + extends FSRuleParam(FeatureSwitchKey.StaleTweetLimitedActionsRulesEnabled, false) + + object FosnrFallbackDropRulesEnabledParam + extends FSRuleParam(FeatureSwitchKey.FosnrFallbackDropRulesEnabled, false) + object FosnrRulesEnabledParam extends FSRuleParam(FeatureSwitchKey.FosnrRulesEnabled, true) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/GlobalParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/GlobalParams.scala new file mode 100644 index 000000000..c6a960486 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/GlobalParams.scala @@ -0,0 +1,11 @@ +package com.twitter.visibility.configapi.params + +import com.twitter.timelines.configapi.Param + +abstract class GlobalParam[T](override val default: T) extends Param(default) { + override val statName: String = s"GlobalParam/${this.getClass.getSimpleName}" +} + +private[visibility] object GlobalParams { + object EnableFetchingLabelMap extends GlobalParam(false) +} diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/LabelSourceParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/LabelSourceParams.scala new file mode 100644 index 000000000..0f0eaaa2d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/LabelSourceParams.scala @@ -0,0 +1,15 @@ +package com.twitter.visibility.configapi.params + +import com.twitter.timelines.configapi.Param + +abstract class LabelSourceParam(override val default: Boolean) extends Param(default) { + override val statName: String = s"LabelSourceParam/${this.getClass.getSimpleName}" +} + +private[visibility] object LabelSourceParams { + object FilterLabelsFromBot7174Param extends LabelSourceParam(false) + + object FilterTweetsSmyteAutomationParamA extends LabelSourceParam(false) + object FilterTweetsSmyteAutomationParamB extends LabelSourceParam(false) + object FilterTweetsSmyteAutomationParamAB extends LabelSourceParam(false) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/RuleParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/RuleParams.scala new file mode 100644 index 000000000..44c7797b9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/RuleParams.scala @@ -0,0 +1,164 @@ +package com.twitter.visibility.configapi.params + +import com.twitter.timelines.configapi.EnumParam +import com.twitter.timelines.configapi.Param + +abstract class RuleParam[T](override val default: T) extends Param(default) { + override val statName: String = s"RuleParam/${this.getClass.getSimpleName}" +} + +abstract class EnumRuleParam[T <: Enumeration](override val default: T#Value, override val enum: T) + extends EnumParam(default, enum) { + override val statName: String = s"RuleParam/${this.getClass.getSimpleName}" +} + +private[visibility] object RuleParams { + object True extends RuleParam(true) + object False extends RuleParam(false) + + object TestHoldbackParam extends RuleParam(true) + + object TweetConversationControlEnabledParam extends RuleParam(default = false) + + object EnableLimitRepliesFollowersConversationRule extends RuleParam(default = false) + + object CommunityTweetsEnabledParam extends RuleParam(default = false) + + object DropCommunityTweetWithUndefinedCommunityRuleEnabledParam extends RuleParam(default = false) + + object EnableHighPSpammyTweetScoreSearchTweetLabelDropRuleParam extends RuleParam(false) + + object EnableSmyteSpamTweetRuleParam extends RuleParam(false) + + object EnableHighSpammyTweetContentScoreSearchLatestTweetLabelDropRuleParam + extends RuleParam(false) + + object EnableHighSpammyTweetContentScoreSearchTopTweetLabelDropRuleParam extends RuleParam(false) + + object NotGraduatedUserLabelRuleHoldbackExperimentParam extends RuleParam(default = false) + + object EnableGoreAndViolenceTopicHighRecallTweetLabelRule extends RuleParam(default = false) + + object EnableBlinkBadDownrankingRuleParam extends RuleParam(false) + object EnableBlinkWorstDownrankingRuleParam extends RuleParam(false) + + object EnableHighSpammyTweetContentScoreTrendsTopTweetLabelDropRuleParam + extends RuleParam(default = false) + + object EnableHighSpammyTweetContentScoreTrendsLatestTweetLabelDropRuleParam + extends RuleParam(default = false) + + object EnableCopypastaSpamDownrankConvosAbusiveQualityRule extends RuleParam(default = false) + object EnableCopypastaSpamSearchDropRule extends RuleParam(default = false) + + object 
EnableSpammyUserModelTweetDropRuleParam extends RuleParam(default = false) + + object EnableAvoidNsfwRulesParam extends RuleParam(false) + + object EnableReportedTweetInterstitialRule extends RuleParam(default = false) + + object EnableReportedTweetInterstitialSearchRule extends RuleParam(default = false) + + object EnableDropExclusiveTweetContentRule extends RuleParam(default = false) + + object EnableDropExclusiveTweetContentRuleFailClosed extends RuleParam(default = false) + + object EnableTombstoneExclusiveQtProfileTimelineParam extends RuleParam(default = false) + + object EnableDropAllExclusiveTweetsRuleParam extends RuleParam(false) + object EnableDropAllExclusiveTweetsRuleFailClosedParam extends RuleParam(false) + + object EnableDownrankSpamReplySectioningRuleParam extends RuleParam(default = false) + object EnableNsfwTextSectioningRuleParam extends RuleParam(default = false) + + object EnableSearchIpiSafeSearchWithoutUserInQueryDropRule extends RuleParam(false) + + object PromotedTweetHealthEnforcementHoldback extends RuleParam(default = true) + object EnableTimelineHomePromotedTweetHealthEnforcementRules extends RuleParam(default = false) + + object EnableMutedKeywordFilteringSpaceTitleNotificationsRuleParam extends RuleParam(false) + + object EnableDropTweetsWithGeoRestrictedMediaRuleParam extends RuleParam(default = false) + + object EnableDropAllTrustedFriendsTweetsRuleParam extends RuleParam(false) + object EnableDropTrustedFriendsTweetContentRuleParam extends RuleParam(false) + + object EnableDropAllCollabInvitationTweetsRuleParam extends RuleParam(false) + + object EnableNsfwTextTopicsDropRuleParam extends RuleParam(false) + + object EnableLikelyIvsUserLabelDropRule extends RuleParam(false) + + object EnableCardUriRootDomainCardDenylistRule extends RuleParam(false) + object EnableCommunityNonMemberPollCardRule extends RuleParam(false) + object EnableCommunityNonMemberPollCardRuleFailClosed extends RuleParam(false) + + object EnableExperimentalNudgeEnabledParam extends RuleParam(false) + + object NsfwHighPrecisionUserLabelAvoidTweetRuleEnabledParam extends RuleParam(default = false) + + object EnableNewAdAvoidanceRulesParam extends RuleParam(false) + + object EnableNsfaHighRecallAdAvoidanceParam extends RuleParam(false) + + object EnableNsfaKeywordsHighPrecisionAdAvoidanceParam extends RuleParam(false) + + object EnableStaleTweetDropRuleParam extends RuleParam(false) + object EnableStaleTweetDropRuleFailClosedParam extends RuleParam(false) + + object EnableDeleteStateTweetRulesParam extends RuleParam(default = false) + + object EnableSpacesSharingNsfwDropRulesParam extends RuleParam(default = true) + + object EnableViewerIsSoftUserDropRuleParam extends RuleParam(default = false) + + object EnablePdnaQuotedTweetTombstoneRuleParam extends RuleParam(default = true) + object EnableSpamQuotedTweetTombstoneRuleParam extends RuleParam(default = true) + + object EnableNsfwHpQuotedTweetDropRuleParam extends RuleParam(default = false) + object EnableNsfwHpQuotedTweetTombstoneRuleParam extends RuleParam(default = false) + + object EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam + extends RuleParam(default = false) + object EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam + extends RuleParam(default = false) + + + object EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam extends RuleParam(false) + + object EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam extends RuleParam(false) + + object 
EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam + extends RuleParam(false) + + object EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam extends RuleParam(false) + + object EnableLegacySensitiveMediaHomeTimelineRulesParam extends RuleParam(true) + + object EnableLegacySensitiveMediaConversationRulesParam extends RuleParam(true) + + object EnableLegacySensitiveMediaProfileTimelineRulesParam extends RuleParam(true) + + object EnableLegacySensitiveMediaTweetDetailRulesParam extends RuleParam(true) + + object EnableLegacySensitiveMediaDirectMessagesRulesParam extends RuleParam(true) + + object EnableToxicReplyFilteringConversationRulesParam extends RuleParam(false) + object EnableToxicReplyFilteringNotificationsRulesParam extends RuleParam(false) + + object EnableSearchQueryMatchesTweetAuthorConditionParam extends RuleParam(default = false) + + object EnableSearchBasicBlockMuteRulesParam extends RuleParam(default = false) + + object EnableAbusiveBehaviorDropRuleParam extends RuleParam(default = false) + object EnableAbusiveBehaviorInterstitialRuleParam extends RuleParam(default = false) + object EnableAbusiveBehaviorLimitedEngagementsRuleParam extends RuleParam(default = false) + + object EnableNotGraduatedDownrankConvosAbusiveQualityRuleParam extends RuleParam(default = false) + object EnableNotGraduatedSearchDropRuleParam extends RuleParam(default = false) + object EnableNotGraduatedDropRuleParam extends RuleParam(default = false) + + object EnableFosnrRuleParam extends RuleParam(default = false) + + object EnableAuthorBlocksViewerDropRuleParam extends RuleParam(default = false) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/SafetyLevelParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/SafetyLevelParams.scala new file mode 100644 index 000000000..a8c7d9f51 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/SafetyLevelParams.scala @@ -0,0 +1,214 @@ +package com.twitter.visibility.configapi.params + +import com.twitter.timelines.configapi.Param + +abstract class SafetyLevelParam(override val default: Boolean) extends Param(default) { + override val statName: String = s"SafetyLevelParam/${this.getClass.getSimpleName}" +} + +private[visibility] object SafetyLevelParams { + object EnableAccessInternalPromotedContentSafetyLevelParam extends SafetyLevelParam(false) + object EnableAdsBusinessSettingsSafetyLevelParam extends SafetyLevelParam(false) + object EnableAdsCampaignSafetyLevelParam extends SafetyLevelParam(false) + object EnableAdsManagerSafetyLevelParam extends SafetyLevelParam(false) + object EnableAdsReportingDashboardSafetyLevelParam extends SafetyLevelParam(false) + object EnableAllSubscribedListsSafetyLevelParam extends SafetyLevelParam(false) + object EnableAppealsSafetyLevelParam extends SafetyLevelParam(false) + object EnableArticleTweetTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableBaseQigSafetyLevelParam extends SafetyLevelParam(false) + object EnableBirdwatchNoteAuthorSafetyLevel extends SafetyLevelParam(false) + object EnableBirdwatchNoteTweetsTimelineSafetyLevel extends SafetyLevelParam(false) + object EnableBirdwatchNeedsYourHelpNotificationsSafetyLevel extends SafetyLevelParam(false) + object EnableBlockMuteUsersTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableBrandSafetySafetyLevelParam extends SafetyLevelParam(false) + object EnableCardPollVotingSafetyLevelParam extends 
SafetyLevelParam(false) + object EnableCardsServiceSafetyLevelParam extends SafetyLevelParam(false) + object EnableCommunitiesSafetyLevelParam extends SafetyLevelParam(false) + object EnableContentControlToolInstallSafetyLevelParam extends SafetyLevelParam(false) + object EnableConversationFocalPrehydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableConversationFocalTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableConversationInjectedTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableConversationReplySafetyLevelParam extends SafetyLevelParam(false) + object EnableCuratedTrendsRepresentativeTweet extends SafetyLevelParam(default = false) + object EnableCurationPolicyViolations extends SafetyLevelParam(default = false) + object EnableDevPlatformGetListTweetsSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESFollowingAndFollowersUserListSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESHomeTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESQuoteTweetTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESRealtimeSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESRealtimeSpamEnrichmentSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESRealtimeTweetFilterSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESRetweetingUsersSafetyLevelParam extends SafetyLevelParam(false) + object EnableDesTweetDetailSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESTweetLikingUsersSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESUserBookmarksSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESUserLikedTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESUserMentionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableDESUserTweetsSafetyLevelParam extends SafetyLevelParam(false) + object EnableDevPlatformComplianceStreamSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesConversationListSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesConversationTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesInboxSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesMutedUsersSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesPinnedSafetyLevelParam extends SafetyLevelParam(false) + object EnableDirectMessagesSearchSafetyLevelParam extends SafetyLevelParam(false) + object EnableEditHistoryTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableElevatedQuoteTweetTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableEmbeddedTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableEmbedsPublicInterestNoticeSafetyLevelParam extends SafetyLevelParam(false) + object EnableEmbedTweetMarkupSafetyLevelParam extends SafetyLevelParam(false) + object EnableWritePathLimitedActionsEnforcementSafetyLevelParam extends SafetyLevelParam(false) + object EnableFilterAllSafetyLevelParam extends SafetyLevelParam(false) + object EnableFilterAllPlaceholderSafetyLevelParam extends SafetyLevelParam(false) + object EnableFilterDefaultSafetyLevelParam extends SafetyLevelParam(false) + object EnableFilterNoneSafetyLevelParam extends SafetyLevelParam(false) + object EnableFollowedTopicsTimelineSafetyLevelParam extends SafetyLevelParam(false) 
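
Each object in this listing (which continues below) is a boolean switch that must be explicitly enabled before the corresponding safety level's rules run; as VisibilityRuleEngine later in this diff shows, a disabled level short-circuits to an Allow verdict. A minimal sketch of how such a param might be consulted, assuming code inside the com.twitter.visibility package (SafetyLevelParams is private[visibility]); `requestParams` and `SafetyLevelParamSketch` are invented names, with `requestParams` standing in for a configapi Params instance built for the current request:

```scala
package com.twitter.visibility.configapi.params

import com.twitter.timelines.configapi.Params

object SafetyLevelParamSketch {
  import SafetyLevelParams._

  // Params#apply resolves the param against any decider or experiment
  // overrides, falling back to the declared default (false here).
  def focalTweetLevelEnabled(requestParams: Params): Boolean =
    requestParams(EnableConversationFocalTweetSafetyLevelParam)
}
```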
+ object EnableFollowerConnectionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableFollowingAndFollowersUserListSafetyLevelParam extends SafetyLevelParam(false) + object EnableForDevelopmentOnlySafetyLevelParam extends SafetyLevelParam(false) + object EnableFriendsFollowingListSafetyLevelParam extends SafetyLevelParam(false) + object EnableGraphqlDefaultSafetyLevelParam extends SafetyLevelParam(false) + object EnableGryphonDecksAndColumnsSafetyLevelParam extends SafetyLevelParam(false) + object EnableHumanizationNudgeSafetyLevelParam extends SafetyLevelParam(false) + object EnableKitchenSinkDevelopmentSafetyLevelParam extends SafetyLevelParam(false) + object EnableListHeaderSafetyLevelParam extends SafetyLevelParam(false) + object EnableListMembershipsSafetyLevelParam extends SafetyLevelParam(false) + object EnableListOwnershipsSafetyLevelParam extends SafetyLevelParam(false) + object EnableListRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableListSearchSafetyLevelParam extends SafetyLevelParam(false) + object EnableListSubscriptionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableLivePipelineEngagementCountsSafetyLevelParam extends SafetyLevelParam(false) + object EnableLiveVideoTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableMagicRecsAggressiveSafetyLevelParam extends SafetyLevelParam(false) + object EnableMagicRecsAggressiveV2SafetyLevelParam extends SafetyLevelParam(false) + object EnableMagicRecsSafetyLevelParam extends SafetyLevelParam(false) + object EnableMagicRecsV2SafetyLevelParam extends SafetyLevelParam(false) + object EnableMinimalSafetyLevelParam extends SafetyLevelParam(false) + object EnableModeratedTweetsTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableMomentsSafetyLevelParam extends SafetyLevelParam(false) + object EnableNearbySafetyLevelParam extends SafetyLevelParam(false) + object EnableNewUserExperienceSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsIbisSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsPlatformSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsPlatformPushSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsQigSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsReadSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsTimelineDeviceFollowSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsWriteSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsWriterTweetHydratorSafetyLevelParam extends SafetyLevelParam(false) + object EnableNotificationsWriterV2SafetyLevelParam extends SafetyLevelParam(false) + object EnableQuickPromoteTweetEligibilitySafetyLevelParam extends SafetyLevelParam(false) + object EnableQuoteTweetTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableRecommendationsWithoutNsfaSafetyLevelParam extends SafetyLevelParam(false) + object EnableRecosVideoSafetyLevelParam extends SafetyLevelParam(false) + object EnableRecosWritePathSafetyLevelParam extends SafetyLevelParam(false) + object EnableRepliesGroupingSafetyLevelParam extends SafetyLevelParam(false) + object EnableReportCenterSafetyLevelParam extends SafetyLevelParam(false) + object EnableReturningUserExperienceFocalTweetSafetyLevelParam extends SafetyLevelParam(false) + object 
EnableReturningUserExperienceSafetyLevelParam extends SafetyLevelParam(false) + object EnableRevenueSafetyLevelParam extends SafetyLevelParam(false) + object EnableRitoActionedTweetTimelineParam extends SafetyLevelParam(false) + object EnableSafeSearchMinimalSafetyLevelParam extends SafetyLevelParam(false) + object EnableSafeSearchStrictSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchHydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchLatestSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchTopSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchTopQigSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchPhotoSafetyLevelParam extends SafetyLevelParam(false) + object SearchTrendTakeoverPromotedTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchVideoSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchBlenderUserRulesSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchLatestUserRulesSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchPeopleSearchResultPageSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchPeopleTypeaheadSafetyLevelParam extends SafetyLevelParam(false) + object EnableUserSearchSrpSafetyLevelParam extends SafetyLevelParam(false) + object EnableUserSearchTypeaheadSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchMixerSrpMinimalSafetyLevelParam extends SafetyLevelParam(false) + object EnableSearchMixerSrpStrictSafetyLevelParam extends SafetyLevelParam(false) + object EnableShoppingManagerSpyModeSafetyLevelParam extends SafetyLevelParam(false) + object EnableSignalsReactionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableSignalsTweetReactingUsersSafetyLevelParam extends SafetyLevelParam(false) + object EnableSocialProofSafetyLevelParam extends SafetyLevelParam(false) + object EnableSoftInterventionPivotSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceFleetlineSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceHomeTimelineUprankingSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceJoinScreenSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceNotificationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpacesSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpacesParticipantsSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceNotificationSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpacesSellerApplicationStatusSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpacesSharingSafetyLevelParam extends SafetyLevelParam(false) + object EnableSpaceTweetAvatarHomeTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableStickersTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableStratoExtLimitedEngagementsSafetyLevelParam extends SafetyLevelParam(false) + object EnableStreamServicesSafetyLevelParam extends SafetyLevelParam(false) + object EnableSuperFollowerConnectionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableSuperLikeSafetyLevelParam extends SafetyLevelParam(false) + object EnableTestSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineBookmarkSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineConversationsDownrankingSafetyLevelParam extends SafetyLevelParam(false) + object 
EnableTimelineContentControlsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineConversationsDownrankingMinimalSafetyLevelParam + extends SafetyLevelParam(false) + object EnableTimelineConversationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineFavoritesSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineFavoritesSelfViewSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineFocalTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineFollowingActivitySafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeCommunitiesSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeHydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeLatestSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomePromotedHydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineHomeTopicFollowRecommendationsSafetyLevelParam + extends SafetyLevelParam(false) + object EnableTimelineScorerSafetyLevelParam extends SafetyLevelParam(false) + object EnableTopicsLandingPageTopicRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableExploreRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineInjectionSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineLikedBySafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineListsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineMediaSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineMentionsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineModeratedTweetsHydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineProfileSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineProfileAllSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineProfileSpacesSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineProfileSuperFollowsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineReactiveBlendingSafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineRetweetedBySafetyLevelParam extends SafetyLevelParam(false) + object EnableTimelineSuperLikedBySafetyLevelParam extends SafetyLevelParam(false) + object EnableTombstoningSafetyLevelParam extends SafetyLevelParam(false) + object EnableTopicRecommendationsSafetyLevelParam extends SafetyLevelParam(false) + object EnableTrendsRepresentativeTweetSafetyLevelParam extends SafetyLevelParam(false) + object EnableTrustedFriendsUserListSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetDetailSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetDetailNonTooSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetDetailWithInjectionsHydrationSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetEngagersSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetReplyNudgeParam extends SafetyLevelParam(false) + object EnableTweetScopedTimelineSafetyLevelParam extends SafetyLevelParam(false) + object EnableTweetWritesApiSafetyLevelParam extends SafetyLevelParam(false) + object EnableTwitterArticleComposeSafetyLevelParam extends SafetyLevelParam(false) + 
object EnableTwitterArticleProfileTabSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableTwitterArticleReadSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableUserProfileHeaderSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableProfileMixerMediaSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableProfileMixerFavoritesSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableUserMilestoneRecommendationSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableUserScopedTimelineSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableUserSelfViewOnlySafetyLevelParam extends SafetyLevelParam(false)
+  object EnableUserSettingsSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableVideoAdsSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableZipbirdConsumerArchivesSafetyLevelParam extends SafetyLevelParam(false)
+  object EnableTweetAwardSafetyLevelParam extends SafetyLevelParam(false)
+
+  object EnableDeprecatedSafetyLevel extends SafetyLevelParam(true)
+  object EnableQuotedTweetRulesParam extends SafetyLevelParam(true)
+  object EnableUnsupportedSafetyLevel extends SafetyLevelParam(true)
+  object EnableUnknownSafetyLevel extends SafetyLevelParam(true)
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/TimelineConversationsDownrankingSpecificParams.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/TimelineConversationsDownrankingSpecificParams.scala
new file mode 100644
index 000000000..eacabbabd
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/TimelineConversationsDownrankingSpecificParams.scala
@@ -0,0 +1,13 @@
+package com.twitter.visibility.configapi.params
+
+private[visibility] object TimelineConversationsDownrankingSpecificParams {
+
+  object EnablePSpammyTweetDownrankConvosLowQualityParam extends RuleParam(false)
+
+  object EnableRitoActionedTweetDownrankConvosLowQualityParam extends RuleParam(false)
+
+  object EnableHighSpammyTweetContentScoreConvoDownrankAbusiveQualityRuleParam
+      extends RuleParam(false)
+
+  object EnableHighCryptospamScoreConvoDownrankAbusiveQualityRuleParam extends RuleParam(false)
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiment.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiment.scala
new file mode 100644
index 000000000..39af78a6b
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiment.scala
@@ -0,0 +1,19 @@
+package com.twitter.visibility.configapi.params
+
+import com.twitter.timelines.configapi.BucketName
+import com.twitter.timelines.configapi.Experiment
+import com.twitter.timelines.configapi.UseFeatureContext
+
+object VisibilityExperiment {
+  val Control = "control"
+  val Treatment = "treatment"
+}
+
+abstract class VisibilityExperiment(experimentKey: String)
+    extends Experiment(experimentKey)
+    with UseFeatureContext {
+  val TreatmentBucket: String = VisibilityExperiment.Treatment
+  override def experimentBuckets: Set[BucketName] = Set(TreatmentBucket)
+  val ControlBucket: String = VisibilityExperiment.Control
+  override def controlBuckets: Set[BucketName] = Set(ControlBucket)
+}
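
VisibilityExperiment above fixes one convention for every visibility experiment: a single "treatment" bucket and a single "control" bucket. A minimal sketch of a concrete subclass, assuming code in the same package; the experiment key below is invented for illustration, while real keys are registered experiment names like those in VisibilityExperiments, which follows:

```scala
package com.twitter.visibility.configapi.params

// Hypothetical experiment; "my_feature_holdback_00000" is an invented key.
case object MyFeatureHoldbackExperiment
    extends VisibilityExperiment("my_feature_holdback_00000")

// The base class pins the bucket sets, so callers can rely on:
//   MyFeatureHoldbackExperiment.experimentBuckets == Set("treatment")
//   MyFeatureHoldbackExperiment.controlBuckets == Set("control")
```

diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiments.scala b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiments.scala
new file mode 100644
index 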
000000000..b287783e6
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/configapi/params/VisibilityExperiments.scala
@@ -0,0 +1,16 @@
+package com.twitter.visibility.configapi.params
+
+private[visibility] object VisibilityExperiments {
+
+  case object TestExperiment extends VisibilityExperiment("vf_test_ddg_7727")
+
+  object CommonBucketId extends Enumeration {
+    type CommonBucketId = Value
+    val Control = Value("control")
+    val Treatment = Value("treatment")
+    val None = Value("none")
+  }
+
+  case object NotGraduatedUserLabelRuleExperiment
+      extends VisibilityExperiment("not_graduated_user_holdback_16332")
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/engine/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/engine/BUILD
new file mode 100644
index 000000000..5e3503374
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/engine/BUILD
@@ -0,0 +1,22 @@
+scala_library(
+    sources = ["*.scala"],
+    compiler_option_sets = ["fatal_warnings"],
+    platform = "java8",
+    strict_deps = True,
+    tags = ["bazel-compatible"],
+    dependencies = [
+        "abdecider/src/main/scala",
+        "configapi/configapi-core",
+        "servo/util/src/main/scala",
+        "src/thrift/com/twitter/search/common:constants-scala",
+        "src/thrift/com/twitter/spam/rtf:safety-level-scala",
+        "src/thrift/com/twitter/tweetypie:tweet-scala",
+        "stitch/stitch-core",
+        "visibility/lib/src/main/scala/com/twitter/visibility/builder",
+        "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params",
+        "visibility/lib/src/main/scala/com/twitter/visibility/features",
+        "visibility/lib/src/main/scala/com/twitter/visibility/models",
+        "visibility/lib/src/main/scala/com/twitter/visibility/rules",
+        "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers",
+    ],
+)
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/engine/DeciderableVisibilityRuleEngine.scala b/visibilitylib/src/main/scala/com/twitter/visibility/engine/DeciderableVisibilityRuleEngine.scala
new file mode 100644
index 000000000..cb1119ce3
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/engine/DeciderableVisibilityRuleEngine.scala
@@ -0,0 +1,26 @@
+package com.twitter.visibility.engine
+
+import com.twitter.servo.util.Gate
+import com.twitter.spam.rtf.thriftscala.{SafetyLevel => ThriftSafetyLevel}
+import com.twitter.stitch.Stitch
+import com.twitter.visibility.builder.VisibilityResult
+import com.twitter.visibility.builder.VisibilityResultBuilder
+import com.twitter.visibility.models.SafetyLevel
+import com.twitter.visibility.rules.EvaluationContext
+import com.twitter.visibility.rules.Rule
+
+trait DeciderableVisibilityRuleEngine {
+  def apply(
+    evaluationContext: EvaluationContext,
+    safetyLevel: SafetyLevel,
+    visibilityResultBuilder: VisibilityResultBuilder,
+    enableShortCircuiting: Gate[Unit] = Gate.True,
+    preprocessedRules: Option[Seq[Rule]] = None
+  ): Stitch[VisibilityResult]
+
+  def apply(
+    evaluationContext: EvaluationContext,
+    thriftSafetyLevel: ThriftSafetyLevel,
+    visibilityResultBuilder: VisibilityResultBuilder
+  ): Stitch[VisibilityResult]
+}
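
DeciderableVisibilityRuleEngine above is the caller-facing contract: given an EvaluationContext, a SafetyLevel (or its Thrift form), and a VisibilityResultBuilder, it yields a Stitch-wrapped VisibilityResult. A minimal sketch of a call site, assuming code in the same package; `EngineCallSketch`, `engine`, `context`, and `builder` are invented names for values wired up elsewhere, and SafetyLevel.TimelineHome is assumed to be one of the model's levels (matching EnableTimelineHomeSafetyLevelParam above):

```scala
package com.twitter.visibility.engine

import com.twitter.stitch.Stitch
import com.twitter.visibility.builder.{VisibilityResult, VisibilityResultBuilder}
import com.twitter.visibility.models.SafetyLevel
import com.twitter.visibility.rules.EvaluationContext

object EngineCallSketch {
  // Relies on the trait's defaults: short-circuiting enabled and
  // no preprocessed rules supplied.
  def verdictFor(
    engine: DeciderableVisibilityRuleEngine,
    context: EvaluationContext,
    builder: VisibilityResultBuilder
  ): Stitch[VisibilityResult] =
    engine(context, SafetyLevel.TimelineHome, builder)
}
```

diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityResultsMetricRecorder.scala b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityResultsMetricRecorder.scala
new file mode 100644
index 000000000..97af1d024
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityResultsMetricRecorder.scala
@@ -0,0 +1,179 @@
+package 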
com.twitter.visibility.engine + +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.finagle.stats.Verbosity +import com.twitter.servo.util.Gate +import com.twitter.servo.util.MemoizingStatsReceiver +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.features.Feature +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.NotEvaluated +import com.twitter.visibility.rules.RuleResult +import com.twitter.visibility.rules.State +import com.twitter.visibility.rules.State.Disabled +import com.twitter.visibility.rules.State.FeatureFailed +import com.twitter.visibility.rules.State.MissingFeature +import com.twitter.visibility.rules.State.RuleFailed +import com.twitter.visibility.rules.Action + + +case class VisibilityResultsMetricRecorder( + statsReceiver: StatsReceiver, + captureDebugStats: Gate[Unit]) { + + private val scopedStatsReceiver = new MemoizingStatsReceiver( + statsReceiver.scope("visibility_rule_engine") + ) + private val actionStats: StatsReceiver = scopedStatsReceiver.scope("by_action") + private val featureFailureReceiver: StatsReceiver = + scopedStatsReceiver.scope("feature_failed") + private val safetyLevelStatsReceiver: StatsReceiver = + scopedStatsReceiver.scope("from_safety_level") + private val ruleStatsReceiver: StatsReceiver = scopedStatsReceiver.scope("for_rule") + private val ruleFailureReceiver: StatsReceiver = + scopedStatsReceiver.scope("rule_failures") + private val failClosedReceiver: StatsReceiver = + scopedStatsReceiver.scope("fail_closed") + private val ruleStatsBySafetyLevelReceiver: StatsReceiver = + scopedStatsReceiver.scope("for_rule_by_safety_level") + + def recordSuccess( + safetyLevel: SafetyLevel, + result: VisibilityResult + ): Unit = { + recordAction(safetyLevel, result.verdict.fullName) + + val isFeatureFailure = result.ruleResultMap.values + .collectFirst { + case RuleResult(_, FeatureFailed(_)) => + ruleFailureReceiver.counter("feature_failed").incr() + true + }.getOrElse(false) + + val isMissingFeature = result.ruleResultMap.values + .collectFirst { + case RuleResult(_, MissingFeature(_)) => + ruleFailureReceiver.counter("missing_feature").incr() + true + }.getOrElse(false) + + val isRuleFailed = result.ruleResultMap.values + .collectFirst { + case RuleResult(_, RuleFailed(_)) => + ruleFailureReceiver.counter("rule_failed").incr() + true + }.getOrElse(false) + + if (isFeatureFailure || isMissingFeature || isRuleFailed) { + ruleFailureReceiver.counter().incr() + } + + if (captureDebugStats()) { + val ruleBySafetyLevelStat = + ruleStatsBySafetyLevelReceiver.scope(safetyLevel.name) + result.ruleResultMap.foreach { + case (rule, ruleResult) => { + ruleBySafetyLevelStat + .scope(rule.name) + .scope("action") + .counter(Verbosity.Debug, ruleResult.action.fullName).incr() + ruleBySafetyLevelStat + .scope(rule.name) + .scope("state") + .counter(Verbosity.Debug, ruleResult.state.name).incr() + } + } + } + } + + def recordFailedFeature( + failedFeature: Feature[_], + exception: Throwable + ): Unit = { + featureFailureReceiver.counter().incr() + + val featureStat = featureFailureReceiver.scope(failedFeature.name) + featureStat.counter().incr() + featureStat.counter(exception.getClass.getName).incr() + } + + def recordAction( + safetyLevel: SafetyLevel, + action: String + ): Unit = { + safetyLevelStatsReceiver.scope(safetyLevel.name).counter(action).incr() + actionStats.counter(action).incr() + } + + def recordUnknownSafetyLevel( + 
safetyLevel: SafetyLevel + ): Unit = { + safetyLevelStatsReceiver + .scope("unknown_safety_level") + .counter(safetyLevel.name.toLowerCase).incr() + } + + def recordRuleMissingFeatures( + ruleName: String, + missingFeatures: Set[Feature[_]] + ): Unit = { + val ruleStat = ruleStatsReceiver.scope(ruleName) + missingFeatures.foreach { featureId => + ruleStat.scope("missing_feature").counter(featureId.name).incr() + } + ruleStat.scope("action").counter(NotEvaluated.fullName).incr() + ruleStat.scope("state").counter(MissingFeature(missingFeatures).name).incr() + } + + def recordRuleFailedFeatures( + ruleName: String, + failedFeatures: Map[Feature[_], Throwable] + ): Unit = { + val ruleStat = ruleStatsReceiver.scope(ruleName) + + ruleStat.scope("action").counter(NotEvaluated.fullName).incr() + ruleStat.scope("state").counter(FeatureFailed(failedFeatures).name).incr() + } + + def recordFailClosed(rule: String, state: State) { + failClosedReceiver.scope(state.name).counter(rule).incr(); + } + + def recordRuleEvaluation( + ruleName: String, + action: Action, + state: State + ): Unit = { + val ruleStat = ruleStatsReceiver.scope(ruleName) + ruleStat.scope("action").counter(action.fullName).incr() + ruleStat.scope("state").counter(state.name).incr() + } + + + def recordRuleFallbackAction( + ruleName: String + ): Unit = { + val ruleStat = ruleStatsReceiver.scope(ruleName) + ruleStat.counter("fallback_action").incr() + } + + def recordRuleHoldBack( + ruleName: String + ): Unit = { + ruleStatsReceiver.scope(ruleName).counter("heldback").incr() + } + + def recordRuleFailed( + ruleName: String + ): Unit = { + ruleStatsReceiver.scope(ruleName).counter("failed").incr() + } + + def recordDisabledRule( + ruleName: String + ): Unit = recordRuleEvaluation(ruleName, NotEvaluated, Disabled) +} + +object NullVisibilityResultsMetricsRecorder + extends VisibilityResultsMetricRecorder(NullStatsReceiver, Gate.False) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRuleEngine.scala b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRuleEngine.scala new file mode 100644 index 000000000..6043f3649 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRuleEngine.scala @@ -0,0 +1,266 @@ +package com.twitter.visibility.engine + +import com.twitter.servo.util.Gate +import com.twitter.spam.rtf.thriftscala.{SafetyLevel => ThriftSafetyLevel} +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.VisibilityResultBuilder +import com.twitter.visibility.features._ +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel.DeprecatedSafetyLevel +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.State._ +import com.twitter.visibility.rules._ +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.providers.PolicyProvider + +class VisibilityRuleEngine private[VisibilityRuleEngine] ( + rulePreprocessor: VisibilityRulePreprocessor, + metricsRecorder: VisibilityResultsMetricRecorder, + enableComposableActions: Gate[Unit], + enableFailClosed: Gate[Unit], + policyProviderOpt: Option[PolicyProvider] = None) + extends DeciderableVisibilityRuleEngine { + + private[visibility] def apply( + evaluationContext: ProvidedEvaluationContext, + visibilityPolicy: VisibilityPolicy, + visibilityResultBuilder: VisibilityResultBuilder, + enableShortCircuiting: 
Gate[Unit], + preprocessedRules: Option[Seq[Rule]] + ): Stitch[VisibilityResult] = { + val (resultBuilder, rules) = preprocessedRules match { + case Some(r) => + (visibilityResultBuilder, r) + case None => + rulePreprocessor.evaluate(evaluationContext, visibilityPolicy, visibilityResultBuilder) + } + evaluate(evaluationContext, resultBuilder, rules, enableShortCircuiting) + } + + def apply( + evaluationContext: EvaluationContext, + safetyLevel: SafetyLevel, + visibilityResultBuilder: VisibilityResultBuilder, + enableShortCircuiting: Gate[Unit] = Gate.True, + preprocessedRules: Option[Seq[Rule]] = None + ): Stitch[VisibilityResult] = { + val visibilityPolicy = policyProviderOpt match { + case Some(policyProvider) => + policyProvider.policyForSurface(safetyLevel) + case None => RuleBase.RuleMap(safetyLevel) + } + if (evaluationContext.params(safetyLevel.enabledParam)) { + apply( + ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = evaluationContext, + safetyLevel = Some(safetyLevel), + policyProviderOpt = policyProviderOpt + ), + visibilityPolicy, + visibilityResultBuilder, + enableShortCircuiting, + preprocessedRules + ).onSuccess { result => + metricsRecorder.recordSuccess(safetyLevel, result) + } + .onFailure { _ => + metricsRecorder.recordAction(safetyLevel, "failure") + } + } else { + metricsRecorder.recordAction(safetyLevel, "disabled") + val rules: Seq[Rule] = visibilityPolicy.forContentId(visibilityResultBuilder.contentId) + Stitch.value( + visibilityResultBuilder + .withRuleResultMap(rules.map(r => r -> RuleResult(Allow, Skipped)).toMap) + .withVerdict(verdict = Allow) + .withFinished(finished = true) + .build + ) + } + } + + def apply( + evaluationContext: EvaluationContext, + thriftSafetyLevel: ThriftSafetyLevel, + visibilityResultBuilder: VisibilityResultBuilder + ): Stitch[VisibilityResult] = { + val safetyLevel: SafetyLevel = SafetyLevel.fromThrift(thriftSafetyLevel) + safetyLevel match { + case DeprecatedSafetyLevel => + metricsRecorder.recordUnknownSafetyLevel(safetyLevel) + Stitch.value( + visibilityResultBuilder + .withVerdict(verdict = Allow) + .withFinished(finished = true) + .build + ) + + case thriftSafetyLevel: SafetyLevel => + this( + ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = evaluationContext, + safetyLevel = Some(safetyLevel), + policyProviderOpt = policyProviderOpt + ), + thriftSafetyLevel, + visibilityResultBuilder + ) + } + } + + private[visibility] def evaluateRules( + evaluationContext: ProvidedEvaluationContext, + resolvedFeatureMap: Map[Feature[_], Any], + failedFeatures: Map[Feature[_], Throwable], + resultBuilderWithoutFailedFeatures: VisibilityResultBuilder, + preprocessedRules: Seq[Rule], + enableShortCircuiting: Gate[Unit] + ): VisibilityResultBuilder = { + preprocessedRules + .foldLeft(resultBuilderWithoutFailedFeatures) { (builder, rule) => + builder.ruleResults.get(rule) match { + case Some(RuleResult(_, state)) if state == Evaluated || state == ShortCircuited => + builder + + case _ => + val failedFeatureDependencies: Map[Feature[_], Throwable] = + failedFeatures.filterKeys(key => rule.featureDependencies.contains(key)) + + val shortCircuit = + builder.finished && enableShortCircuiting() && + !(enableComposableActions() && builder.isVerdictComposable()) + + if (failedFeatureDependencies.nonEmpty && rule.fallbackActionBuilder.isEmpty) { + metricsRecorder.recordRuleFailedFeatures(rule.name, failedFeatureDependencies) + builder.withRuleResult( + rule, + 
RuleResult(NotEvaluated, FeatureFailed(failedFeatureDependencies))) + + } else if (shortCircuit) { + + metricsRecorder.recordRuleEvaluation(rule.name, NotEvaluated, ShortCircuited) + builder.withRuleResult(rule, RuleResult(builder.verdict, ShortCircuited)) + } else { + + if (rule.fallbackActionBuilder.nonEmpty) { + metricsRecorder.recordRuleFallbackAction(rule.name) + } + + + val ruleResult = + rule.evaluate(evaluationContext, resolvedFeatureMap) + metricsRecorder + .recordRuleEvaluation(rule.name, ruleResult.action, ruleResult.state) + val nextBuilder = (ruleResult.action, builder.finished) match { + case (NotEvaluated | Allow, _) => + ruleResult.state match { + case Heldback => + metricsRecorder.recordRuleHoldBack(rule.name) + case RuleFailed(_) => + metricsRecorder.recordRuleFailed(rule.name) + case _ => + } + builder.withRuleResult(rule, ruleResult) + + case (_, true) => + builder + .withRuleResult(rule, ruleResult) + .withSecondaryVerdict(ruleResult.action, rule) + + case _ => + builder + .withRuleResult(rule, ruleResult) + .withVerdict(ruleResult.action, Some(rule)) + .withFinished(true) + } + + nextBuilder + } + } + }.withResolvedFeatureMap(resolvedFeatureMap) + } + + private[visibility] def evaluateFailClosed( + evaluationContext: ProvidedEvaluationContext + ): VisibilityResultBuilder => Stitch[VisibilityResultBuilder] = { builder => + builder.failClosedException(evaluationContext) match { + case Some(e: FailClosedException) if enableFailClosed() => + metricsRecorder.recordFailClosed(e.getRuleName, e.getState); + Stitch.exception(e) + case _ => Stitch.value(builder) + } + } + + private[visibility] def checkMarkFinished( + builder: VisibilityResultBuilder + ): VisibilityResult = { + val allRulesEvaluated: Boolean = builder.ruleResults.values.forall { + case RuleResult(_, state) => + state == Evaluated || state == Disabled || state == Skipped + case _ => + false + } + + if (allRulesEvaluated) { + builder.withFinished(true).build + } else { + builder.build + } + } + + private[visibility] def evaluate( + evaluationContext: ProvidedEvaluationContext, + visibilityResultBuilder: VisibilityResultBuilder, + preprocessedRules: Seq[Rule], + enableShortCircuiting: Gate[Unit] = Gate.True + ): Stitch[VisibilityResult] = { + + val finalBuilder = + FeatureMap.resolve(visibilityResultBuilder.features, evaluationContext.statsReceiver).map { + resolvedFeatureMap => + val (failedFeatureMap, successfulFeatureMap) = resolvedFeatureMap.constantMap.partition({ + case (_, _: FeatureFailedPlaceholderObject) => true + case _ => false + }) + + val failedFeatures: Map[Feature[_], Throwable] = + failedFeatureMap.mapValues({ + case failurePlaceholder: FeatureFailedPlaceholderObject => + failurePlaceholder.throwable + }) + + val resultBuilderWithoutFailedFeatures = + visibilityResultBuilder.withFeatureMap(ResolvedFeatureMap(successfulFeatureMap)) + + evaluateRules( + evaluationContext, + successfulFeatureMap, + failedFeatures, + resultBuilderWithoutFailedFeatures, + preprocessedRules, + enableShortCircuiting + ) + } + + finalBuilder.flatMap(evaluateFailClosed(evaluationContext)).map(checkMarkFinished) + } +} + +object VisibilityRuleEngine { + + def apply( + rulePreprocessor: Option[VisibilityRulePreprocessor] = None, + metricsRecorder: VisibilityResultsMetricRecorder = NullVisibilityResultsMetricsRecorder, + enableComposableActions: Gate[Unit] = Gate.False, + enableFailClosed: Gate[Unit] = Gate.False, + policyProviderOpt: Option[PolicyProvider] = None, + ): VisibilityRuleEngine = { + new VisibilityRuleEngine( + 
rulePreprocessor.getOrElse(VisibilityRulePreprocessor(metricsRecorder)), + metricsRecorder, + enableComposableActions, + enableFailClosed, + policyProviderOpt = policyProviderOpt) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRulePreprocessor.scala b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRulePreprocessor.scala new file mode 100644 index 000000000..115c37605 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/engine/VisibilityRulePreprocessor.scala @@ -0,0 +1,156 @@ +package com.twitter.visibility.engine + +import com.twitter.abdecider.NullABDecider +import com.twitter.util.Return +import com.twitter.util.Throw +import com.twitter.util.Try +import com.twitter.visibility.builder.VisibilityResultBuilder +import com.twitter.visibility.features._ +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.Rule.DisabledRuleResult +import com.twitter.visibility.rules.Rule.EvaluatedRuleResult +import com.twitter.visibility.rules.State._ +import com.twitter.visibility.rules._ +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.providers.PolicyProvider + +class VisibilityRulePreprocessor private ( + metricsRecorder: VisibilityResultsMetricRecorder, + policyProviderOpt: Option[PolicyProvider] = None) { + + private[engine] def filterEvaluableRules( + evaluationContext: ProvidedEvaluationContext, + resultBuilder: VisibilityResultBuilder, + rules: Seq[Rule] + ): (VisibilityResultBuilder, Seq[Rule]) = { + val (builder, ruleList) = rules.foldLeft((resultBuilder, Seq.empty[Rule])) { + case ((builder, nextPassRules), rule) => + if (evaluationContext.ruleEnabledInContext(rule)) { + val missingFeatures: Set[Feature[_]] = rule.featureDependencies.collect { + case feature: Feature[_] if !builder.featureMap.contains(feature) => feature + } + + if (missingFeatures.isEmpty) { + (builder, nextPassRules :+ rule) + } else { + metricsRecorder.recordRuleMissingFeatures(rule.name, missingFeatures) + ( + builder.withRuleResult( + rule, + RuleResult(NotEvaluated, MissingFeature(missingFeatures)) + ), + nextPassRules + ) + } + } else { + (builder.withRuleResult(rule, DisabledRuleResult), nextPassRules) + } + } + (builder, ruleList) + } + + private[visibility] def preFilterRules( + evaluationContext: ProvidedEvaluationContext, + resolvedFeatureMap: Map[Feature[_], Any], + resultBuilder: VisibilityResultBuilder, + rules: Seq[Rule] + ): (VisibilityResultBuilder, Seq[Rule]) = { + val isResolvedFeatureMap = resultBuilder.featureMap.isInstanceOf[ResolvedFeatureMap] + val (builder, ruleList) = rules.foldLeft((resultBuilder, Seq.empty[Rule])) { + case ((builder, nextPassRules), rule) => + rule.preFilter(evaluationContext, resolvedFeatureMap, NullABDecider) match { + case NeedsFullEvaluation => + (builder, nextPassRules :+ rule) + case NotFiltered => + (builder, nextPassRules :+ rule) + case Filtered if isResolvedFeatureMap => + (builder, nextPassRules :+ rule) + case Filtered => + (builder.withRuleResult(rule, EvaluatedRuleResult), nextPassRules) + } + } + (builder, ruleList) + } + + private[visibility] def evaluate( + evaluationContext: ProvidedEvaluationContext, + safetyLevel: SafetyLevel, + resultBuilder: VisibilityResultBuilder + ): (VisibilityResultBuilder, Seq[Rule]) = { + val visibilityPolicy = policyProviderOpt match { + case Some(policyProvider) => + policyProvider.policyForSurface(safetyLevel) + case None => RuleBase.RuleMap(safetyLevel) + } + + 
if (evaluationContext.params(safetyLevel.enabledParam)) { + evaluate(evaluationContext, visibilityPolicy, resultBuilder) + } else { + metricsRecorder.recordAction(safetyLevel, "disabled") + + val rules: Seq[Rule] = visibilityPolicy.forContentId(resultBuilder.contentId) + val skippedResultBuilder = resultBuilder + .withRuleResultMap(rules.map(r => r -> RuleResult(Allow, Skipped)).toMap) + .withVerdict(verdict = Allow) + .withFinished(finished = true) + + (skippedResultBuilder, rules) + } + } + + private[visibility] def evaluate( + evaluationContext: ProvidedEvaluationContext, + visibilityPolicy: VisibilityPolicy, + resultBuilder: VisibilityResultBuilder, + ): (VisibilityResultBuilder, Seq[Rule]) = { + + val rules: Seq[Rule] = visibilityPolicy.forContentId(resultBuilder.contentId) + + val (secondPassBuilder, secondPassRules) = + filterEvaluableRules(evaluationContext, resultBuilder, rules) + + val secondPassFeatureMap = secondPassBuilder.featureMap + + val secondPassConstantFeatures: Set[Feature[_]] = RuleBase + .getFeaturesForRules(secondPassRules) + .filter(secondPassFeatureMap.containsConstant(_)) + + val secondPassFeatureValues: Set[(Feature[_], Any)] = secondPassConstantFeatures.map { + feature => + Try(secondPassFeatureMap.getConstant(feature)) match { + case Return(value) => (feature, value) + case Throw(ex) => + metricsRecorder.recordFailedFeature(feature, ex) + (feature, FeatureFailedPlaceholderObject(ex)) + } + } + + val resolvedFeatureMap: Map[Feature[_], Any] = + secondPassFeatureValues.filterNot { + case (_, value) => value.isInstanceOf[FeatureFailedPlaceholderObject] + }.toMap + + val (preFilteredResultBuilder, preFilteredRules) = preFilterRules( + evaluationContext, + resolvedFeatureMap, + secondPassBuilder, + secondPassRules + ) + + val preFilteredFeatureMap = + RuleBase.removeUnusedFeaturesFromFeatureMap( + preFilteredResultBuilder.featureMap, + preFilteredRules) + + (preFilteredResultBuilder.withFeatureMap(preFilteredFeatureMap), preFilteredRules) + } +} + +object VisibilityRulePreprocessor { + def apply( + metricsRecorder: VisibilityResultsMetricRecorder, + policyProviderOpt: Option[PolicyProvider] = None + ): VisibilityRulePreprocessor = { + new VisibilityRulePreprocessor(metricsRecorder, policyProviderOpt) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/features/AdvancedFilteringFeatures.scala b/visibilitylib/src/main/scala/com/twitter/visibility/features/AdvancedFilteringFeatures.scala new file mode 100644 index 000000000..4e6a33ba9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/features/AdvancedFilteringFeatures.scala @@ -0,0 +1,24 @@ +package com.twitter.visibility.features + +import com.twitter.gizmoduck.thriftscala.MentionFilter +import com.twitter.util.Duration + +case object ViewerFiltersNoConfirmedEmail extends Feature[Boolean] + +case object ViewerFiltersNoConfirmedPhone extends Feature[Boolean] + +case object ViewerFiltersDefaultProfileImage extends Feature[Boolean] + +case object ViewerFiltersNewUsers extends Feature[Boolean] + +case object ViewerFiltersNotFollowedBy extends Feature[Boolean] + +case object ViewerMentionFilter extends Feature[MentionFilter] + +case object AuthorHasConfirmedEmail extends Feature[Boolean] + +case object AuthorHasVerifiedPhone extends Feature[Boolean] + +case object AuthorHasDefaultProfileImage extends Feature[Boolean] + +case object AuthorAccountAge extends Feature[Duration] diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/features/BUILD 
b/visibilitylib/src/main/scala/com/twitter/visibility/features/BUILD
new file mode 100644
index 000000000..3573bb0db
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/features/BUILD
@@ -0,0 +1,17 @@
+scala_library(
+    sources = ["*.scala"],
+    compiler_option_sets = ["fatal_warnings"],
+    platform = "java8",
+    strict_deps = True,
+    tags = ["bazel-compatible"],
+    dependencies = [
+        "3rdparty/jvm/com/squareup/okhttp:okhttp3",
+        "finagle/finagle-mux/src/main/scala",
+        "src/thrift/com/twitter/gizmoduck:user-thrift-scala",
+        "src/thrift/com/twitter/search/common:constants-scala",
+        "src/thrift/com/twitter/tweetypie:tweet-scala",
+        "stitch/stitch-core",
+        "visibility/lib/src/main/scala/com/twitter/visibility/models",
+        "visibility/lib/src/main/scala/com/twitter/visibility/util",
+    ],
+)
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/features/Feature.scala b/visibilitylib/src/main/scala/com/twitter/visibility/features/Feature.scala
new file mode 100644
index 000000000..151718814
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/features/Feature.scala
@@ -0,0 +1,11 @@
+package com.twitter.visibility.features
+
+import com.twitter.visibility.util.NamingUtils
+
+abstract class Feature[T] protected ()(implicit val manifest: Manifest[T]) {
+
+  lazy val name: String = NamingUtils.getFriendlyName(this)
+
+  override lazy val toString: String =
+    "Feature[%s](name=%s)".format(manifest, getClass.getSimpleName)
+}
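
Feature above is only a typed marker; concrete features are declared as case objects (as in AdvancedFilteringFeatures earlier), and their values travel in the FeatureMap defined next. A minimal sketch, assuming code in the same package; MyExampleFeature and FeatureMapSketch are invented for illustration:

```scala
package com.twitter.visibility.features

import com.twitter.stitch.Stitch

// Hypothetical feature: the type parameter fixes the value type a
// FeatureMap may carry for it.
case object MyExampleFeature extends Feature[Boolean]

object FeatureMapSketch {
  // Constant values are read synchronously; reading an absent feature
  // throws MissingFeatureException (see FeatureMap below).
  val fm = new FeatureMap(
    map = Map.empty[Feature[_], Stitch[_]],
    constantMap = Map[Feature[_], Any](MyExampleFeature -> true)
  )
  val enabled: Boolean = fm.getConstant(MyExampleFeature)
}
```

diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/features/FeatureMap.scala b/visibilitylib/src/main/scala/com/twitter/visibility/features/FeatureMap.scala
new file mode 100644
index 000000000..1b4ffd182
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/features/FeatureMap.scala
@@ -0,0 +1,121 @@
+package com.twitter.visibility.features
+
+import com.twitter.finagle.mux.ClientDiscardedRequestException
+import com.twitter.finagle.stats.NullStatsReceiver
+import com.twitter.finagle.stats.StatsReceiver
+import com.twitter.stitch.Stitch
+import scala.language.existentials
+
+class MissingFeatureException(feature: Feature[_]) extends Exception("Missing value for " + feature)
+
+case class FeatureFailedException(feature: Feature[_], exception: Throwable) extends Exception
+
+private[visibility] case class FeatureFailedPlaceholderObject(throwable: Throwable)
+
+class FeatureMap(
+  val map: Map[Feature[_], Stitch[_]],
+  val constantMap: Map[Feature[_], Any]) {
+
+  def contains[T](feature: Feature[T]): Boolean =
+    constantMap.contains(feature) || map.contains(feature)
+
+  def containsConstant[T](feature: Feature[T]): Boolean = constantMap.contains(feature)
+
+  lazy val size: Int = keys.size
+
+  lazy val keys: Set[Feature[_]] = constantMap.keySet ++ map.keySet
+
+  def get[T](feature: Feature[T]): Stitch[T] = {
+    map.get(feature) match {
+      case _ if constantMap.contains(feature) =>
+        Stitch.value(getConstant(feature))
+      case Some(x) =>
+        x.asInstanceOf[Stitch[T]]
+      case _ =>
+        Stitch.exception(new MissingFeatureException(feature))
+    }
+  }
+
+  def getConstant[T](feature: Feature[T]): T = {
+    constantMap.get(feature) match {
+      case Some(x) =>
+        x.asInstanceOf[T]
+      case _ =>
+        throw new MissingFeatureException(feature)
+    }
+  }
+
+  def -[T](key: Feature[T]): FeatureMap = new FeatureMap(map - key, constantMap - key)
+
+  override def toString: String = "FeatureMap(%s, %s)".format(map, constantMap)
+}
+
+object FeatureMap {
+
+  def empty: FeatureMap = new FeatureMap(Map.empty, Map.empty)
+
+  def resolve(
+    featureMap: 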
FeatureMap, + statsReceiver: StatsReceiver = NullStatsReceiver + ): Stitch[ResolvedFeatureMap] = { + val featureMapHydrationStatsReceiver = statsReceiver.scope("feature_map_hydration") + + Stitch + .traverse(featureMap.map.toSeq) { + case (feature, value: Stitch[_]) => + val featureStatsReceiver = featureMapHydrationStatsReceiver.scope(feature.name) + lazy val featureFailureStat = featureStatsReceiver.scope("failures") + val featureStitch: Stitch[(Feature[_], Any)] = value + .map { resolvedValue => + featureStatsReceiver.counter("success").incr() + (feature, resolvedValue) + } + + featureStitch + .handle { + case ffe: FeatureFailedException => + featureFailureStat.counter().incr() + featureFailureStat.counter(ffe.exception.getClass.getName).incr() + (feature, FeatureFailedPlaceholderObject(ffe.exception)) + } + .ensure { + featureStatsReceiver.counter("requests").incr() + } + } + .map { resolvedFeatures: Seq[(Feature[_], Any)] => + new ResolvedFeatureMap(resolvedFeatures.toMap ++ featureMap.constantMap) + } + } + + def rescueFeatureTuple(kv: (Feature[_], Stitch[_])): (Feature[_], Stitch[_]) = { + val (k, v) = kv + + val rescueValue = v.rescue { + case e => + e match { + case cdre: ClientDiscardedRequestException => Stitch.exception(cdre) + case _ => Stitch.exception(FeatureFailedException(k, e)) + } + } + + (k, rescueValue) + } +} + +class ResolvedFeatureMap(private[visibility] val resolvedMap: Map[Feature[_], Any]) + extends FeatureMap(Map.empty, resolvedMap) { + + override def equals(other: Any): Boolean = other match { + case otherResolvedFeatureMap: ResolvedFeatureMap => + this.resolvedMap.equals(otherResolvedFeatureMap.resolvedMap) + case _ => false + } + + override def toString: String = "ResolvedFeatureMap(%s)".format(resolvedMap) +} + +object ResolvedFeatureMap { + def apply(resolvedMap: Map[Feature[_], Any]): ResolvedFeatureMap = { + new ResolvedFeatureMap(resolvedMap) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/features/Features.scala b/visibilitylib/src/main/scala/com/twitter/visibility/features/Features.scala new file mode 100644 index 000000000..ae26dfe78 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/features/Features.scala @@ -0,0 +1,269 @@ +package com.twitter.visibility.features + +import com.twitter.contenthealth.toxicreplyfilter.thriftscala.FilterState +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource +import com.twitter.tseng.withholding.thriftscala.TakedownReason +import com.twitter.util.Duration +import com.twitter.util.Time +import com.twitter.visibility.models.TweetDeleteReason.TweetDeleteReason +import com.twitter.visibility.models._ + +case object AuthorId extends Feature[Set[Long]] + +case object ViewerId extends Feature[Long] + +case object AuthorIsProtected extends Feature[Boolean] + +case object AuthorIsSuspended extends Feature[Boolean] + +case object AuthorIsUnavailable extends Feature[Boolean] + +case object AuthorIsDeactivated extends Feature[Boolean] + +case object AuthorIsErased extends Feature[Boolean] + +case object AuthorIsOffboarded extends Feature[Boolean] + +case object AuthorIsVerified extends Feature[Boolean] + +case object AuthorIsBlueVerified extends Feature[Boolean] + +case object ViewerIsSuspended extends Feature[Boolean] + +case object ViewerIsDeactivated extends Feature[Boolean] + +case object AuthorFollowsViewer extends Feature[Boolean] + +case object AuthorUserLabels extends Feature[Seq[Label]] + +case object 
ViewerFollowsAuthorOfViolatingTweet extends Feature[Boolean]
+
+case object ViewerDoesNotFollowAuthorOfViolatingTweet extends Feature[Boolean]
+
+case object ViewerFollowsAuthor extends Feature[Boolean]
+
+case object ViewerBlocksAuthor extends Feature[Boolean]
+
+case object AuthorBlocksViewer extends Feature[Boolean]
+
+case object AuthorMutesViewer extends Feature[Boolean]
+
+case object ViewerMutesAuthor extends Feature[Boolean]
+
+case object AuthorReportsViewerAsSpam extends Feature[Boolean]
+
+case object ViewerReportsAuthorAsSpam extends Feature[Boolean]
+
+case object ViewerReportedTweet extends Feature[Boolean]
+
+case object ViewerMutesRetweetsFromAuthor extends Feature[Boolean]
+
+case object ViewerHasUniversalQualityFilterEnabled extends Feature[Boolean]
+
+case object ViewerIsProtected extends Feature[Boolean]
+
+case object ViewerIsSoftUser extends Feature[Boolean]
+
+case object TweetSafetyLabels extends Feature[Seq[TweetSafetyLabel]]
+
+case object SpaceSafetyLabels extends Feature[Seq[SpaceSafetyLabel]]
+
+case object MediaSafetyLabels extends Feature[Seq[MediaSafetyLabel]]
+
+case object TweetTakedownReasons extends Feature[Seq[TakedownReason]]
+
+case object AuthorTakedownReasons extends Feature[Seq[TakedownReason]]
+
+case object AuthorIsNsfwUser extends Feature[Boolean]
+
+case object AuthorIsNsfwAdmin extends Feature[Boolean]
+
+case object TweetHasNsfwUser extends Feature[Boolean]
+
+case object TweetHasNsfwAdmin extends Feature[Boolean]
+
+case object TweetHasMedia extends Feature[Boolean]
+
+case object CardHasMedia extends Feature[Boolean]
+
+case object TweetHasCard extends Feature[Boolean]
+
+case object ViewerMutesKeywordInTweetForHomeTimeline extends Feature[MutedKeyword]
+
+case object ViewerMutesKeywordInTweetForTweetReplies extends Feature[MutedKeyword]
+
+case object ViewerMutesKeywordInTweetForNotifications extends Feature[MutedKeyword]
+
+case object ViewerMutesKeywordInSpaceTitleForNotifications extends Feature[MutedKeyword]
+
+case object ViewerMutesKeywordInTweetForAllSurfaces extends Feature[MutedKeyword]
+
+case object ViewerUserLabels extends Feature[Seq[Label]]
+
+case object RequestCountryCode extends Feature[String]
+
+case object RequestIsVerifiedCrawler extends Feature[Boolean]
+
+case object ViewerCountryCode extends Feature[String]
+
+case object TweetIsSelfReply extends Feature[Boolean]
+
+case object TweetIsNullcast extends Feature[Boolean]
+
+case object TweetTimestamp extends Feature[Time]
+
+case object TweetIsInnerQuotedTweet extends Feature[Boolean]
+
+case object TweetIsRetweet extends Feature[Boolean]
+
+case object TweetIsSourceTweet extends Feature[Boolean]
+
+case object TweetDeleteReason extends Feature[TweetDeleteReason]
+
+case object TweetReplyToParentTweetDuration extends Feature[Duration]
+
+case object TweetReplyToRootTweetDuration extends Feature[Duration]
+
+case object TweetHasCommunityConversationControl extends Feature[Boolean]
+case object TweetHasByInvitationConversationControl extends Feature[Boolean]
+case object TweetHasFollowersConversationControl extends Feature[Boolean]
+case object TweetConversationViewerIsInvited extends Feature[Boolean]
+case object TweetConversationViewerIsInvitedViaReplyMention extends Feature[Boolean]
+case object TweetConversationViewerIsRootAuthor extends Feature[Boolean]
+case object ConversationRootAuthorFollowsViewer extends Feature[Boolean]
+case object ViewerFollowsConversationRootAuthor extends Feature[Boolean]
+
+case object TweetIsExclusiveTweet extends Feature[Boolean]
+case object ViewerIsExclusiveTweetRootAuthor extends Feature[Boolean]
+case object ViewerSuperFollowsExclusiveTweetRootAuthor extends Feature[Boolean]
+
+case object TweetIsCommunityTweet extends Feature[Boolean]
+
+case object CommunityTweetCommunityNotFound extends Feature[Boolean]
+
+case object CommunityTweetCommunityDeleted extends Feature[Boolean]
+
+case object CommunityTweetCommunitySuspended extends Feature[Boolean]
+
+case object CommunityTweetCommunityVisible extends Feature[Boolean]
+
+case object CommunityTweetIsHidden extends Feature[Boolean]
+
+case object ViewerIsInternalCommunitiesAdmin extends Feature[Boolean]
+
+case object ViewerIsCommunityAdmin extends Feature[Boolean]
+
+case object ViewerIsCommunityModerator extends Feature[Boolean]
+
+case object ViewerIsCommunityMember extends Feature[Boolean]
+
+case object CommunityTweetAuthorIsRemoved extends Feature[Boolean]
+
+case object NotificationIsOnCommunityTweet extends Feature[Boolean]
+
+case object NotificationIsOnUnmentionedViewer extends Feature[Boolean]
+
+case object SearchResultsPageNumber extends Feature[Int]
+
+case object SearchCandidateCount extends Feature[Int]
+
+case object SearchQuerySource extends Feature[ThriftQuerySource]
+
+case object SearchQueryHasUser extends Feature[Boolean]
+
+case object TweetSemanticCoreAnnotations extends Feature[Seq[SemanticCoreAnnotation]]
+
+case object OuterAuthorId extends Feature[Long]
+
+case object AuthorBlocksOuterAuthor extends Feature[Boolean]
+
+case object OuterAuthorFollowsAuthor extends Feature[Boolean]
+
+case object OuterAuthorIsInnerAuthor extends Feature[Boolean]
+
+case object TweetIsModerated extends Feature[Boolean]
+case object FocalTweetId extends Feature[Long]
+
+case object TweetId extends Feature[Long]
+
+case object TweetConversationId extends Feature[Long]
+case object TweetParentId extends Feature[Long]
+case object ConversationRootAuthorIsVerified extends Feature[Boolean]
+
+case object ViewerOptInBlocking extends Feature[Boolean]
+
+case object ViewerOptInFiltering extends Feature[Boolean]
+
+case object ViewerRoles extends Feature[Seq[String]] {
+  val EmployeeRole = "employee"
+}
+
+case object TweetMisinformationPolicies extends Feature[Seq[MisinformationPolicy]]
+
+case object TweetEnglishMisinformationPolicies extends Feature[Seq[MisinformationPolicy]]
+
+case object HasInnerCircleOfFriendsRelationship extends Feature[Boolean]
+
+case object ViewerAge extends Feature[UserAge]
+
+case object HasDmcaMediaFeature extends Feature[Boolean]
+
+case object MediaGeoRestrictionsAllowList extends Feature[Seq[String]]
+case object MediaGeoRestrictionsDenyList extends Feature[Seq[String]]
+
+case object TweetIsTrustedFriendTweet extends Feature[Boolean]
+case object ViewerIsTrustedFriendTweetAuthor extends Feature[Boolean]
+case object ViewerIsTrustedFriendOfTweetAuthor extends Feature[Boolean]
+
+case object DmConversationIsOneToOneConversation extends Feature[Boolean]
+case object DmConversationHasEmptyTimeline extends Feature[Boolean]
+case object DmConversationHasValidLastReadableEventId extends Feature[Boolean]
+case object DmConversationInfoExists extends Feature[Boolean]
+case object DmConversationTimelineExists extends Feature[Boolean]
+case object ViewerIsDmConversationParticipant extends Feature[Boolean]
+
+case object DmEventIsMessageCreateEvent extends Feature[Boolean]
+case object DmEventIsWelcomeMessageCreateEvent extends Feature[Boolean]
+case object DmEventIsLastMessageReadUpdateEvent extends Feature[Boolean]
+case object DmEventIsDeleted extends Feature[Boolean]
+case object DmEventIsHidden extends Feature[Boolean]
+case object ViewerIsDmEventInitiatingUser extends Feature[Boolean]
+case object DmEventInOneToOneConversationWithUnavailableUser extends Feature[Boolean]
+case object DmEventIsJoinConversationEvent extends Feature[Boolean]
+case object DmEventIsConversationCreateEvent extends Feature[Boolean]
+case object DmEventInOneToOneConversation extends Feature[Boolean]
+case object DmEventIsTrustConversationEvent extends Feature[Boolean]
+case object DmEventIsCsFeedbackSubmitted extends Feature[Boolean]
+case object DmEventIsCsFeedbackDismissed extends Feature[Boolean]
+case object DmEventIsPerspectivalJoinConversationEvent extends Feature[Boolean]
+
+case object DmEventOccurredBeforeLastClearedEvent extends Feature[Boolean]
+case object DmEventOccurredBeforeJoinConversationEvent extends Feature[Boolean]
+
+case object CardUriHost extends Feature[String]
+case object CardIsPoll extends Feature[Boolean]
+
+case object TweetIsStaleTweet extends Feature[Boolean]
+
+case object TweetIsEditTweet extends Feature[Boolean]
+
+case object TweetIsLatestTweet extends Feature[Boolean]
+
+case object TweetIsInitialTweet extends Feature[Boolean]
+
+case object TweetIsCollabInvitationTweet extends Feature[Boolean]
+
+case object ViewerSensitiveMediaSettings extends Feature[UserSensitiveMediaSettings]
+
+
+case object ToxicReplyFilterState extends Feature[FilterState]
+
+
+case object ToxicReplyFilterConversationAuthorIsViewer extends Feature[Boolean]
+
+case object RawQuery extends Feature[String]
+
+case object AuthorScreenName extends Feature[String]
+
+case object TweetIsInternalPromotedContent extends Feature[Boolean]
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/generators/BUILD
new file mode 100644
index 000000000..c53b6b59d
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/BUILD
@@ -0,0 +1,30 @@
+scala_library(
+    sources = ["*.scala"],
+    compiler_option_sets = ["fatal_warnings"],
+    platform = "java8",
+    strict_deps = True,
+    tags = ["bazel-compatible"],
+    dependencies = [
+        "3rdparty/jvm/com/ibm/icu:icu4j",
+        "configapi/configapi-core",
+        "decider/src/main/scala",
+        "src/thrift/com/twitter/gizmoduck:thrift-scala",
+        "src/thrift/com/twitter/gizmoduck:user-thrift-scala",
+        "src/thrift/com/twitter/spam/rtf:safety-result-scala",
+        "stitch/stitch-core",
+        "strato/src/main/scala/com/twitter/strato/client",
+        "twitter-config/yaml",
+        "visibility/common/src/main/scala/com/twitter/visibility/common",
+        "visibility/common/src/main/scala/com/twitter/visibility/common/actions",
+        "visibility/common/src/main/scala/com/twitter/visibility/common/user_result",
+        "visibility/common/src/main/thrift/com/twitter/visibility:action-scala",
+        "visibility/lib/src/main/scala/com/twitter/visibility/builder",
+        "visibility/lib/src/main/scala/com/twitter/visibility/configapi",
+        "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs",
+        "visibility/lib/src/main/scala/com/twitter/visibility/features",
+        "visibility/lib/src/main/scala/com/twitter/visibility/models",
+        "visibility/lib/src/main/scala/com/twitter/visibility/rules",
+        "visibility/results/src/main/scala/com/twitter/visibility/results/richtext",
+        "visibility/results/src/main/scala/com/twitter/visibility/results/translation",
+    ],
+)
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/CountryNameGenerator.scala b/visibilitylib/src/main/scala/com/twitter/visibility/generators/CountryNameGenerator.scala
new file mode 100644
index 000000000..014533a43
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/CountryNameGenerator.scala
@@ -0,0 +1,58 @@
+package com.twitter.visibility.generators
+
+import com.ibm.icu.util.ULocale
+import com.twitter.config.yaml.YamlMap
+import com.twitter.finagle.stats.StatsReceiver
+
+object CountryNameGenerator {
+
+  private val AuroraFilesystemPath = "/usr/local/twitter-config/twitter/config/"
+
+  private val ContentBlockingSupportedCountryList = "takedown_countries.yml"
+
+  def providesFromConfigBus(statsReceiver: StatsReceiver): CountryNameGenerator = {
+    fromFile(AuroraFilesystemPath + ContentBlockingSupportedCountryList, statsReceiver)
+  }
+
+  def providesWithCustomMap(countryCodeMap: Map[String, String], statsReceiver: StatsReceiver) = {
+    new CountryNameGenerator(countryCodeMap, statsReceiver)
+  }
+
+  private def fromFile(fileName: String, statsReceiver: StatsReceiver) = {
+    val yamlConfig = YamlMap.load(fileName)
+    val countryCodeMap: Map[String, String] = yamlConfig.keySet.map { countryCode: String =>
+      val normalizedCode = countryCode.toUpperCase
+      val countryName: Option[String] =
+        yamlConfig.get(Seq(countryCode, "name")).asInstanceOf[Option[String]]
+      (normalizedCode, countryName.getOrElse(normalizedCode))
+    }.toMap
+    new CountryNameGenerator(countryCodeMap, statsReceiver)
+  }
+}
+
+class CountryNameGenerator(countryCodeMap: Map[String, String], statsReceiver: StatsReceiver) {
+
+  private val scopedStatsReceiver = statsReceiver.scope("country_name_generator")
+  private val foundCountryReceiver = scopedStatsReceiver.counter("found")
+  private val missingCountryReceiver = scopedStatsReceiver.counter("missing")
+
+  def getCountryName(code: String): String = {
+    val normalizedCode = code.toUpperCase
+    countryCodeMap.get(normalizedCode) match {
+      case Some(retrievedName) => {
+        foundCountryReceiver.incr()
+        retrievedName
+      }
+      case _ => {
+        missingCountryReceiver.incr()
+        val fallbackName =
+          new ULocale("", normalizedCode).getDisplayCountry(ULocale.forLanguageTag("en"))
+
+        if (fallbackName == "")
+          normalizedCode
+        else
+          fallbackName
+      }
+    }
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/EpitaphToLocalizedMessage.scala b/visibilitylib/src/main/scala/com/twitter/visibility/generators/EpitaphToLocalizedMessage.scala
new file mode 100644
index 000000000..af5266848
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/EpitaphToLocalizedMessage.scala
@@ -0,0 +1,66 @@
+package com.twitter.visibility.generators
+
+import com.twitter.visibility.common.actions.LocalizedMessage
+import com.twitter.visibility.common.actions.MessageLink
+import com.twitter.visibility.results.translation.Translator
+import com.twitter.visibility.results.richtext.EpitaphToRichText
+import com.twitter.visibility.results.translation.Resource
+import com.twitter.visibility.results.translation.LearnMoreLink
+import com.twitter.visibility.rules.Epitaph
+import com.twitter.visibility.results.richtext.EpitaphToRichText.Copy
+
+object EpitaphToLocalizedMessage {
+  def apply(
+    epitaph: Epitaph,
+    languageTag: String,
+  ): LocalizedMessage = {
+    val copy =
+      EpitaphToRichText.EpitaphToPolicyMap.getOrElse(epitaph, EpitaphToRichText.FallbackPolicy)
+    val text = Translator.translate(
+      copy.resource,
+      languageTag
+    )
+    localizeWithCopyAndText(copy, languageTag, text)
+  }
+
+  def apply(
+    epitaph: Epitaph,
+    languageTag: String,
+    applicableCountries: Seq[String],
+  ): LocalizedMessage = {
+    val copy =
+      EpitaphToRichText.EpitaphToPolicyMap.getOrElse(epitaph, EpitaphToRichText.FallbackPolicy)
+    val text = Translator.translateWithSimplePlaceholderReplacement(
+      copy.resource,
+      languageTag,
+      Map((Resource.ApplicableCountriesPlaceholder -> applicableCountries.mkString(", ")))
+    )
+    localizeWithCopyAndText(copy, languageTag, text)
+  }
+
+  private def localizeWithCopyAndText(
+    copy: Copy,
+    languageTag: String,
+    text: String
+  ): LocalizedMessage = {
+    val learnMore = Translator.translate(LearnMoreLink, languageTag)
+
+    val links = copy.additionalLinks match {
+      case links if links.nonEmpty =>
+        MessageLink(Resource.LearnMorePlaceholder, learnMore, copy.link) +:
+          links.map {
+            case EpitaphToRichText.Link(placeholder, copyResource, link) =>
+              val copyText = Translator.translate(copyResource, languageTag)
+              MessageLink(placeholder, copyText, link)
+          }
+      case _ =>
+        Seq(
+          MessageLink(
+            key = Resource.LearnMorePlaceholder,
+            displayText = learnMore,
+            uri = copy.link))
+    }
+
+    LocalizedMessage(message = text, language = languageTag, links = links)
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/InterstitialReasonToLocalizedMessage.scala b/visibilitylib/src/main/scala/com/twitter/visibility/generators/InterstitialReasonToLocalizedMessage.scala
new file mode 100644
index 000000000..f4e000338
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/InterstitialReasonToLocalizedMessage.scala
@@ -0,0 +1,47 @@
+package com.twitter.visibility.generators
+
+import com.twitter.visibility.common.actions.InterstitialReason
+import com.twitter.visibility.common.actions.LocalizedMessage
+import com.twitter.visibility.common.actions.MessageLink
+import com.twitter.visibility.results.richtext.InterstitialReasonToRichText
+import com.twitter.visibility.results.richtext.InterstitialReasonToRichText.InterstitialCopy
+import com.twitter.visibility.results.richtext.InterstitialReasonToRichText.InterstitialLink
+import com.twitter.visibility.results.translation.LearnMoreLink
+import com.twitter.visibility.results.translation.Resource
+import com.twitter.visibility.results.translation.Translator
+
+object InterstitialReasonToLocalizedMessage {
+  def apply(
+    reason: InterstitialReason,
+    languageTag: String,
+  ): Option[LocalizedMessage] = {
+    InterstitialReasonToRichText.reasonToCopy(reason).map { copy =>
+      val text = Translator.translate(
+        copy.resource,
+        languageTag
+      )
+      localizeWithCopyAndText(copy, languageTag, text)
+    }
+  }
+
+  private def localizeWithCopyAndText(
+    copy: InterstitialCopy,
+    languageTag: String,
+    text: String
+  ): LocalizedMessage = {
+    val learnMore = Translator.translate(LearnMoreLink, languageTag)
+
+    val learnMoreLinkOpt =
+      copy.link.map { link =>
+        MessageLink(key = Resource.LearnMorePlaceholder, displayText = learnMore, uri = link)
+      }
+    val additionalLinks = copy.additionalLinks.map {
+      case InterstitialLink(placeholder, copyResource, link) =>
+        val copyText = Translator.translate(copyResource, languageTag)
+        MessageLink(key = placeholder, displayText = copyText, uri = link)
+    }
+
+    val links = learnMoreLinkOpt.toSeq ++ additionalLinks
+    LocalizedMessage(message = text, language = languageTag, links = links)
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/LocalizedInterstitialGenerator.scala
b/visibilitylib/src/main/scala/com/twitter/visibility/generators/LocalizedInterstitialGenerator.scala new file mode 100644 index 000000000..6d381642e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/LocalizedInterstitialGenerator.scala @@ -0,0 +1,151 @@ +package com.twitter.visibility.generators + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.common.actions.LocalizedMessage +import com.twitter.visibility.common.actions.MessageLink +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.results.richtext.PublicInterestReasonToRichText +import com.twitter.visibility.results.translation.LearnMoreLink +import com.twitter.visibility.results.translation.Resource +import com.twitter.visibility.results.translation.SafetyResultReasonToResource +import com.twitter.visibility.results.translation.Translator +import com.twitter.visibility.rules.EmergencyDynamicInterstitial +import com.twitter.visibility.rules.Interstitial +import com.twitter.visibility.rules.InterstitialLimitedEngagements +import com.twitter.visibility.rules.PublicInterest +import com.twitter.visibility.rules.Reason +import com.twitter.visibility.rules.TweetInterstitial + +object LocalizedInterstitialGenerator { + def apply( + visibilityDecider: Decider, + baseStatsReceiver: StatsReceiver, + ): LocalizedInterstitialGenerator = { + new LocalizedInterstitialGenerator(visibilityDecider, baseStatsReceiver) + } +} + +class LocalizedInterstitialGenerator private ( + val visibilityDecider: Decider, + val baseStatsReceiver: StatsReceiver) { + + private val visibilityDeciderGates = VisibilityDeciderGates(visibilityDecider) + private val localizationStatsReceiver = baseStatsReceiver.scope("interstitial_localization") + private val publicInterestInterstitialStats = + localizationStatsReceiver.scope("public_interest_copy") + private val emergencyDynamicInterstitialStats = + localizationStatsReceiver.scope("emergency_dynamic_copy") + private val regularInterstitialStats = localizationStatsReceiver.scope("interstitial_copy") + + def apply(visibilityResult: VisibilityResult, languageTag: String): VisibilityResult = { + if (!visibilityDeciderGates.enableLocalizedInterstitialGenerator()) { + return visibilityResult + } + + visibilityResult.verdict match { + case ipi: InterstitialLimitedEngagements if PublicInterest.Reasons.contains(ipi.reason) => + visibilityResult.copy( + verdict = ipi.copy( + localizedMessage = Some(localizePublicInterestCopyInResult(ipi, languageTag)) + )) + case edi: EmergencyDynamicInterstitial => + visibilityResult.copy( + verdict = EmergencyDynamicInterstitial( + edi.copy, + edi.linkOpt, + Some(localizeEmergencyDynamicCopyInResult(edi, languageTag)) + )) + case interstitial: Interstitial => + visibilityResult.copy( + verdict = interstitial.copy( + localizedMessage = localizeInterstitialCopyInResult(interstitial, languageTag) + )) + case tweetInterstitial: TweetInterstitial if tweetInterstitial.interstitial.isDefined => + tweetInterstitial.interstitial.get match { + case ipi: InterstitialLimitedEngagements if PublicInterest.Reasons.contains(ipi.reason) => + visibilityResult.copy( + verdict = tweetInterstitial.copy( + interstitial = Some( + ipi.copy( + localizedMessage = Some(localizePublicInterestCopyInResult(ipi, languageTag)) + )) + )) + case edi: EmergencyDynamicInterstitial => + visibilityResult.copy( + verdict = 
tweetInterstitial.copy( + interstitial = Some( + EmergencyDynamicInterstitial( + edi.copy, + edi.linkOpt, + Some(localizeEmergencyDynamicCopyInResult(edi, languageTag)) + )) + )) + case interstitial: Interstitial => + visibilityResult.copy( + verdict = tweetInterstitial.copy( + interstitial = Some( + interstitial.copy( + localizedMessage = localizeInterstitialCopyInResult(interstitial, languageTag) + )) + )) + case _ => visibilityResult + } + case _ => visibilityResult + } + } + + private def localizeEmergencyDynamicCopyInResult( + edi: EmergencyDynamicInterstitial, + languageTag: String + ): LocalizedMessage = { + val text = edi.linkOpt + .map(_ => s"${edi.copy} {${Resource.LearnMorePlaceholder}}") + .getOrElse(edi.copy) + + val messageLinks = edi.linkOpt + .map { link => + val learnMoreText = Translator.translate(LearnMoreLink, languageTag) + Seq(MessageLink(Resource.LearnMorePlaceholder, learnMoreText, link)) + }.getOrElse(Seq.empty) + + emergencyDynamicInterstitialStats.counter("localized").incr() + LocalizedMessage(text, languageTag, messageLinks) + } + + private def localizePublicInterestCopyInResult( + ipi: InterstitialLimitedEngagements, + languageTag: String + ): LocalizedMessage = { + val safetyResultReason = PublicInterest.ReasonToSafetyResultReason(ipi.reason) + val text = Translator.translate( + SafetyResultReasonToResource.resource(safetyResultReason), + languageTag, + ) + + val learnMoreLink = PublicInterestReasonToRichText.toLearnMoreLink(safetyResultReason) + val learnMoreText = Translator.translate(LearnMoreLink, languageTag) + val messageLinks = Seq(MessageLink(Resource.LearnMorePlaceholder, learnMoreText, learnMoreLink)) + + publicInterestInterstitialStats.counter("localized").incr() + LocalizedMessage(text, languageTag, messageLinks) + } + + private def localizeInterstitialCopyInResult( + interstitial: Interstitial, + languageTag: String + ): Option[LocalizedMessage] = { + val localizedMessageOpt = Reason + .toInterstitialReason(interstitial.reason) + .flatMap(InterstitialReasonToLocalizedMessage(_, languageTag)) + + if (localizedMessageOpt.isDefined) { + regularInterstitialStats.counter("localized").incr() + localizedMessageOpt + } else { + regularInterstitialStats.counter("empty").incr() + None + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/generators/TombstoneGenerator.scala b/visibilitylib/src/main/scala/com/twitter/visibility/generators/TombstoneGenerator.scala new file mode 100644 index 000000000..9d52cc217 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/generators/TombstoneGenerator.scala @@ -0,0 +1,94 @@ +package com.twitter.visibility.generators + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.servo.util.MemoizingStatsReceiver +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.common.actions.TombstoneReason +import com.twitter.visibility.configapi.VisibilityParams +import com.twitter.visibility.rules.Epitaph +import com.twitter.visibility.rules.LocalizedTombstone +import com.twitter.visibility.rules.Tombstone + +object TombstoneGenerator { + def apply( + visibilityParams: VisibilityParams, + countryNameGenerator: CountryNameGenerator, + statsReceiver: StatsReceiver + ): TombstoneGenerator = { + new TombstoneGenerator(visibilityParams, countryNameGenerator, statsReceiver) + } +} + +class TombstoneGenerator( + paramsFactory: VisibilityParams, + countryNameGenerator: CountryNameGenerator, + baseStatsReceiver: StatsReceiver) { + + private[this] val 
statsReceiver = new MemoizingStatsReceiver(
+    baseStatsReceiver.scope("tombstone_generator"))
+  private[this] val deletedReceiver = statsReceiver.scope("deleted_state")
+  private[this] val authorStateReceiver = statsReceiver.scope("tweet_author_state")
+  private[this] val visResultReceiver = statsReceiver.scope("visibility_result")
+
+  def apply(
+    result: VisibilityResult,
+    language: String
+  ): VisibilityResult = {
+
+    result.verdict match {
+      case tombstone: Tombstone =>
+        val epitaph = tombstone.epitaph
+        visResultReceiver.scope("tombstone").counter(epitaph.name.toLowerCase()).incr()
+
+        val overriddenLanguage = epitaph match {
+          case Epitaph.LegalDemandsWithheldMedia | Epitaph.LocalLawsWithheldMedia => "en"
+          case _ => language
+        }
+
+        tombstone.applicableCountryCodes match {
+          case Some(countryCodes) => {
+            val countryNames = countryCodes.map(countryNameGenerator.getCountryName(_))
+
+            result.copy(verdict = LocalizedTombstone(
+              reason = epitaphToTombstoneReason(epitaph),
+              message = EpitaphToLocalizedMessage(epitaph, overriddenLanguage, countryNames)))
+          }
+          case _ => {
+            result.copy(verdict = LocalizedTombstone(
+              reason = epitaphToTombstoneReason(epitaph),
+              message = EpitaphToLocalizedMessage(epitaph, overriddenLanguage)))
+          }
+        }
+      case _ =>
+        result
+    }
+  }
+
+  private def epitaphToTombstoneReason(epitaph: Epitaph): TombstoneReason = {
+    epitaph match {
+      case Epitaph.Deleted => TombstoneReason.Deleted
+      case Epitaph.Bounced => TombstoneReason.Bounced
+      case Epitaph.BounceDeleted => TombstoneReason.BounceDeleted
+      case Epitaph.Protected => TombstoneReason.ProtectedAuthor
+      case Epitaph.Suspended => TombstoneReason.SuspendedAuthor
+      case Epitaph.BlockedBy => TombstoneReason.AuthorBlocksViewer
+      case Epitaph.SuperFollowsContent => TombstoneReason.ExclusiveTweet
+      case Epitaph.Underage => TombstoneReason.NsfwViewerIsUnderage
+      case Epitaph.NoStatedAge => TombstoneReason.NsfwViewerHasNoStatedAge
+      case Epitaph.LoggedOutAge => TombstoneReason.NsfwLoggedOut
+      case Epitaph.Deactivated => TombstoneReason.DeactivatedAuthor
+      case Epitaph.CommunityTweetHidden => TombstoneReason.CommunityTweetHidden
+      case Epitaph.CommunityTweetCommunityIsSuspended =>
+        TombstoneReason.CommunityTweetCommunityIsSuspended
+      case Epitaph.DevelopmentOnly => TombstoneReason.DevelopmentOnly
+      case Epitaph.AdultMedia => TombstoneReason.AdultMedia
+      case Epitaph.ViolentMedia => TombstoneReason.ViolentMedia
+      case Epitaph.OtherSensitiveMedia => TombstoneReason.OtherSensitiveMedia
+      case Epitaph.DmcaWithheldMedia => TombstoneReason.DmcaWithheldMedia
+      case Epitaph.LegalDemandsWithheldMedia => TombstoneReason.LegalDemandsWithheldMedia
+      case Epitaph.LocalLawsWithheldMedia => TombstoneReason.LocalLawsWithheldMedia
+      case Epitaph.ToxicReplyFiltered => TombstoneReason.ReplyFiltered
+      case _ => TombstoneReason.Unspecified
+    }
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BUILD
new file mode 100644
index 000000000..545aee9e5
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BUILD
@@ -0,0 +1,34 @@
+scala_library(
+    sources = ["*.scala"],
+    compiler_option_sets = ["fatal_warnings"],
+    platform = "java8",
+    strict_deps = True,
+    tags = ["bazel-compatible"],
+    dependencies = [
+        "3rdparty/jvm/com/twitter/src/java/com/twitter/logpipeline/client:logpipeline-event-publisher-thin",
+        "decider/src/main/scala",
+        "mediaservices/media-util/src/main/scala",
+
"servo/decider/src/main/scala", + "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/client", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/media", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/blender", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityLibrary.scala new file mode 100644 index 000000000..c83756818 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityLibrary.scala @@ -0,0 +1,416 @@ +package com.twitter.visibility.interfaces.blender + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.util.Stopwatch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VerdictLogger +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.media.MediaFeatures +import com.twitter.visibility.builder.media.StratoMediaLabelMaps +import com.twitter.visibility.builder.tweets._ +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.MediaSafetyLabelMapSource +import com.twitter.visibility.common.MisinformationPolicySource +import com.twitter.visibility.common.SafetyLabelMapSource +import com.twitter.visibility.common.TrustedFriendsSource +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.rules.ComposableActions.ComposableActionsWithInterstitial +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.features.TweetIsSourceTweet +import com.twitter.visibility.logging.thriftscala.VFLibType +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.ContentId.BlenderTweetId +import 
com.twitter.visibility.models.ContentId.TweetId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel.toThrift +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Interstitial +import com.twitter.visibility.rules.TweetInterstitial + +object TweetType extends Enumeration { + type TweetType = Value + val ORIGINAL, SOURCE, QUOTED = Value +} +import com.twitter.visibility.interfaces.blender.TweetType._ + +object BlenderVisibilityLibrary { + def buildWithStratoClient( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource + ): BlenderVisibilityLibrary = new BlenderVisibilityLibrary( + visibilityLibrary, + decider, + stratoClient, + userSource, + userRelationshipSource, + None + ) + + def buildWithSafetyLabelMapSource( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + safetyLabelMapSource: SafetyLabelMapSource + ): BlenderVisibilityLibrary = new BlenderVisibilityLibrary( + visibilityLibrary, + decider, + stratoClient, + userSource, + userRelationshipSource, + Some(safetyLabelMapSource) + ) + + def createVerdictLogger( + enableVerdictLogger: Gate[Unit], + decider: Decider, + statsReceiver: StatsReceiver + ): VerdictLogger = { + if (enableVerdictLogger()) { + VerdictLogger(statsReceiver, decider) + } else { + VerdictLogger.Empty + } + } + + def scribeVisibilityVerdict( + result: CombinedVisibilityResult, + enableVerdictScribing: Gate[Unit], + verdictLogger: VerdictLogger, + viewerId: Option[Long], + safetyLevel: SafetyLevel + ): Unit = if (enableVerdictScribing()) { + verdictLogger.scribeVerdict( + visibilityResult = result.tweetVisibilityResult, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.BlenderVisibilityLibrary) + + result.quotedTweetVisibilityResult.map(quotedTweetVisibilityResult => + verdictLogger.scribeVerdict( + visibilityResult = quotedTweetVisibilityResult, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.BlenderVisibilityLibrary)) + } +} + +class BlenderVisibilityLibrary( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + safetyLabelMapSourceOption: Option[SafetyLabelMapSource]) { + + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val bvlRequestCounter = libraryStatsReceiver.counter("bvl_requests") + val vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = libraryStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + val visibilityDeciderGates = VisibilityDeciderGates(decider) + val verdictLogger = BlenderVisibilityLibrary.createVerdictLogger( + visibilityDeciderGates.enableVerdictLoggerBVL, + decider, + libraryStatsReceiver) + + val tweetLabels = safetyLabelMapSourceOption match { + case Some(safetyLabelMapSource) => + new StratoTweetLabelMaps(safetyLabelMapSource) + case None => + new StratoTweetLabelMaps( + 
SafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + } + + val mediaLabelMaps = new StratoMediaLabelMaps( + MediaSafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + + val tweetFeatures = new TweetFeatures(tweetLabels, libraryStatsReceiver) + val blenderContextFeatures = new BlenderContextFeatures(libraryStatsReceiver) + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val fonsrRelationshipFeatures = + new FosnrRelationshipFeatures( + tweetLabels = tweetLabels, + userRelationshipSource = userRelationshipSource, + statsReceiver = libraryStatsReceiver) + val misinfoPolicySource = + MisinformationPolicySource.fromStrato(stratoClient, stratoClientStatsReceiver) + val misinfoPolicyFeatures = + new MisinformationPolicyFeatures(misinfoPolicySource, stratoClientStatsReceiver) + val exclusiveTweetFeatures = + new ExclusiveTweetFeatures(userRelationshipSource, libraryStatsReceiver) + val mediaFeatures = new MediaFeatures(mediaLabelMaps, libraryStatsReceiver) + val trustedFriendsTweetFeatures = new TrustedFriendsFeatures( + trustedFriendsSource = TrustedFriendsSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + val editTweetFeatures = new EditTweetFeatures(libraryStatsReceiver) + + def getCombinedVisibilityResult( + bvRequest: BlenderVisibilityRequest + ): Stitch[CombinedVisibilityResult] = { + val elapsed = Stopwatch.start() + bvlRequestCounter.incr() + + val ( + requestTweetVisibilityResult, + quotedTweetVisibilityResultOption, + sourceTweetVisibilityResultOption + ) = getAllVisibilityResults(bvRequest: BlenderVisibilityRequest) + + val response: Stitch[CombinedVisibilityResult] = { + ( + requestTweetVisibilityResult, + quotedTweetVisibilityResultOption, + sourceTweetVisibilityResultOption) match { + case (requestTweetVisResult, Some(quotedTweetVisResult), Some(sourceTweetVisResult)) => { + Stitch + .join( + requestTweetVisResult, + quotedTweetVisResult, + sourceTweetVisResult + ).map { + case (requestTweetVisResult, quotedTweetVisResult, sourceTweetVisResult) => { + requestTweetVisResult.verdict match { + case Allow => + CombinedVisibilityResult(sourceTweetVisResult, Some(quotedTweetVisResult)) + case _ => + CombinedVisibilityResult(requestTweetVisResult, Some(quotedTweetVisResult)) + } + } + } + } + + case (requestTweetVisResult, None, Some(sourceTweetVisResult)) => { + Stitch + .join( + requestTweetVisResult, + sourceTweetVisResult + ).map { + case (requestTweetVisResult, sourceTweetVisResult) => { + requestTweetVisResult.verdict match { + case Allow => + CombinedVisibilityResult(sourceTweetVisResult, None) + case _ => + CombinedVisibilityResult(requestTweetVisResult, None) + } + } + } + } + + case (requestTweetVisResult, Some(quotedTweetVisResult), None) => { + Stitch + .join( + requestTweetVisResult, + quotedTweetVisResult + ).map { + case (requestTweetVisResult, quotedTweetVisResult) => { + CombinedVisibilityResult(requestTweetVisResult, Some(quotedTweetVisResult)) + } + } + } + + case (requestTweetVisResult, None, None) => { + requestTweetVisResult.map { + CombinedVisibilityResult(_, None) + } + } + } + } + val runStitchStartMs = elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + response + .onSuccess(_ => { + val overallMs = elapsed().inMilliseconds + 
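// end-to-end latency for this request; by construction it spans both the Stitch build phase and the run phase measured below +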
vfLatencyOverallStat.add(overallMs) + val stitchRunMs = elapsed().inMilliseconds - runStitchStartMs + vfLatencyStitchRunStat.add(stitchRunMs) + }) + .onSuccess( + BlenderVisibilityLibrary.scribeVisibilityVerdict( + _, + visibilityDeciderGates.enableVerdictScribingBVL, + verdictLogger, + bvRequest.viewerContext.userId, + bvRequest.safetyLevel)) + } + + def getContentId(viewerId: Option[Long], authorId: Long, tweet: Tweet): ContentId = { + if (viewerId.contains(authorId)) + TweetId(tweet.id) + else BlenderTweetId(tweet.id) + } + + def getAllVisibilityResults(bvRequest: BlenderVisibilityRequest): ( + Stitch[VisibilityResult], + Option[Stitch[VisibilityResult]], + Option[Stitch[VisibilityResult]] + ) = { + val tweetContentId = getContentId( + viewerId = bvRequest.viewerContext.userId, + authorId = bvRequest.tweet.coreData.get.userId, + tweet = bvRequest.tweet) + + val tweetFeatureMap = + buildFeatureMap(bvRequest, bvRequest.tweet, ORIGINAL) + vfEngineCounter.incr() + val requestTweetVisibilityResult = visibilityLibrary + .runRuleEngine( + tweetContentId, + tweetFeatureMap, + bvRequest.viewerContext, + bvRequest.safetyLevel + ).map(handleComposableVisibilityResult) + + val quotedTweetVisibilityResultOption = bvRequest.quotedTweet.map(quotedTweet => { + val quotedTweetContentId = getContentId( + viewerId = bvRequest.viewerContext.userId, + authorId = quotedTweet.coreData.get.userId, + tweet = quotedTweet) + + val quotedInnerTweetFeatureMap = + buildFeatureMap(bvRequest, quotedTweet, QUOTED) + vfEngineCounter.incr() + visibilityLibrary + .runRuleEngine( + quotedTweetContentId, + quotedInnerTweetFeatureMap, + bvRequest.viewerContext, + bvRequest.safetyLevel + ) + .map(handleComposableVisibilityResult) + .map(handleInnerQuotedTweetVisibilityResult) + }) + + val sourceTweetVisibilityResultOption = bvRequest.retweetSourceTweet.map(sourceTweet => { + val sourceTweetContentId = getContentId( + viewerId = bvRequest.viewerContext.userId, + authorId = sourceTweet.coreData.get.userId, + tweet = sourceTweet) + + val sourceTweetFeatureMap = + buildFeatureMap(bvRequest, sourceTweet, SOURCE) + vfEngineCounter.incr() + visibilityLibrary + .runRuleEngine( + sourceTweetContentId, + sourceTweetFeatureMap, + bvRequest.viewerContext, + bvRequest.safetyLevel + ) + .map(handleComposableVisibilityResult) + }) + + ( + requestTweetVisibilityResult, + quotedTweetVisibilityResultOption, + sourceTweetVisibilityResultOption) + } + + def buildFeatureMap( + bvRequest: BlenderVisibilityRequest, + tweet: Tweet, + tweetType: TweetType + ): FeatureMap = { + val authorId = tweet.coreData.get.userId + val viewerId = bvRequest.viewerContext.userId + val isRetweet = if (tweetType.equals(ORIGINAL)) bvRequest.isRetweet else false + val isSourceTweet = tweetType.equals(SOURCE) + val isQuotedTweet = tweetType.equals(QUOTED) + val tweetMediaKeys: Seq[GenericMediaKey] = tweet.media + .getOrElse(Seq.empty) + .flatMap(_.mediaKey.map(GenericMediaKey.apply)) + + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures + .forViewerBlenderContext(bvRequest.blenderVFRequestContext, bvRequest.viewerContext), + relationshipFeatures.forAuthorId(authorId, viewerId), + fonsrRelationshipFeatures + .forTweetAndAuthorId(tweet = tweet, authorId = authorId, viewerId = viewerId), + tweetFeatures.forTweet(tweet), + mediaFeatures.forMediaKeys(tweetMediaKeys), + authorFeatures.forAuthorId(authorId), + blenderContextFeatures.forBlenderContext(bvRequest.blenderVFRequestContext), + _.withConstantFeature(TweetIsRetweet, isRetweet), + 
misinfoPolicyFeatures.forTweet(tweet, bvRequest.viewerContext), + exclusiveTweetFeatures.forTweet(tweet, bvRequest.viewerContext), + trustedFriendsTweetFeatures.forTweet(tweet, viewerId), + editTweetFeatures.forTweet(tweet), + _.withConstantFeature(TweetIsInnerQuotedTweet, isQuotedTweet), + _.withConstantFeature(TweetIsSourceTweet, isSourceTweet), + ) + ) + } + + def handleComposableVisibilityResult(result: VisibilityResult): VisibilityResult = { + if (result.secondaryVerdicts.nonEmpty) { + result.copy(verdict = composeActions(result.verdict, result.secondaryVerdicts)) + } else { + result + } + } + + private def composeActions(primary: Action, secondary: Seq[Action]): Action = { + if (primary.isComposable && secondary.nonEmpty) { + val actions = Seq[Action] { primary } ++ secondary + val interstitialOpt = Action.getFirstInterstitial(actions: _*) + val softInterventionOpt = Action.getFirstSoftIntervention(actions: _*) + val limitedEngagementsOpt = Action.getFirstLimitedEngagements(actions: _*) + val avoidOpt = Action.getFirstAvoid(actions: _*) + + val numActions = + Seq[Option[_]](interstitialOpt, softInterventionOpt, limitedEngagementsOpt, avoidOpt) + .count(_.isDefined) + if (numActions > 1) { + TweetInterstitial( + interstitialOpt, + softInterventionOpt, + limitedEngagementsOpt, + None, + avoidOpt + ) + } else { + primary + } + } else { + primary + } + } + + def handleInnerQuotedTweetVisibilityResult( + result: VisibilityResult + ): VisibilityResult = { + val newVerdict: Action = + result.verdict match { + case interstitial: Interstitial => Drop(interstitial.reason) + case ComposableActionsWithInterstitial(tweetInterstitial) => Drop(tweetInterstitial.reason) + case verdict => verdict + } + + result.copy(verdict = newVerdict) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityRequest.scala new file mode 100644 index 000000000..aa6e604a5 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/BlenderVisibilityRequest.scala @@ -0,0 +1,42 @@ +package com.twitter.visibility.interfaces.blender + +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.interfaces.common.blender.BlenderVFRequestContext + +case class BlenderVisibilityRequest( + tweet: Tweet, + quotedTweet: Option[Tweet], + retweetSourceTweet: Option[Tweet] = None, + isRetweet: Boolean, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext, + blenderVFRequestContext: BlenderVFRequestContext) { + + def getTweetID: Long = tweet.id + + def hasQuotedTweet: Boolean = { + quotedTweet.nonEmpty + } + def hasSourceTweet: Boolean = { + retweetSourceTweet.nonEmpty + } + + def getQuotedTweetId: Long = { + quotedTweet match { + case Some(qTweet) => + qTweet.id + case None => + -1 + } + } + def getSourceTweetId: Long = { + retweetSourceTweet match { + case Some(sourceTweet) => + sourceTweet.id + case None => + -1 + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/CombinedVisibilityResult.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/CombinedVisibilityResult.scala new file mode 100644 index 000000000..6868b67ec --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/blender/CombinedVisibilityResult.scala @@ -0,0 +1,7 @@ 
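+// Pairs the primary tweet's visibility result with the optional result for its quoted tweet (built by BlenderVisibilityLibrary above).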
+package com.twitter.visibility.interfaces.blender + +import com.twitter.visibility.builder.VisibilityResult + +case class CombinedVisibilityResult( + tweetVisibilityResult: VisibilityResult, + quotedTweetVisibilityResult: Option[VisibilityResult]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/BUILD new file mode 100644 index 000000000..b2b0f998f --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/BUILD @@ -0,0 +1,17 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = False, + tags = ["bazel-compatible"], + dependencies = [ + "appsec/sanitization-lib/src/main/scala", + "src/thrift/com/twitter/expandodo:cards-scala", + "stitch/stitch-core", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibrary.scala new file mode 100644 index 000000000..575356901 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibrary.scala @@ -0,0 +1,187 @@ +package com.twitter.visibility.interfaces.cards + +import com.twitter.appsec.sanitization.URLSafety +import com.twitter.decider.Decider +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.{thriftscala => tweetypiethrift} +import com.twitter.util.Stopwatch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.CommunityTweetFeatures +import com.twitter.visibility.builder.tweets.CommunityTweetFeaturesV2 +import com.twitter.visibility.builder.tweets.NilTweetLabelMaps +import com.twitter.visibility.builder.tweets.TweetFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.CommunitiesSource +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.CardIsPoll +import com.twitter.visibility.features.CardUriHost +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.CardId +import com.twitter.visibility.models.ViewerContext + +object CardVisibilityLibrary { + type Type = CardVisibilityRequest => Stitch[VisibilityResult] + + private[this] def getAuthorFeatures( + authorIdOpt: Option[Long], + authorFeatures: AuthorFeatures + ): FeatureMapBuilder => FeatureMapBuilder = { + authorIdOpt match { + case Some(authorId) => authorFeatures.forAuthorId(authorId) + case _ => authorFeatures.forNoAuthor() + } + } + + private[this] def getCardUriFeatures( + cardUri: String, + enableCardVisibilityLibraryCardUriParsing: Boolean, + 
trackCardUriHost: Option[String] => Unit + ): FeatureMapBuilder => FeatureMapBuilder = { + if (enableCardVisibilityLibraryCardUriParsing) { + val safeCardUriHost = URLSafety.getHostSafe(cardUri) + trackCardUriHost(safeCardUriHost) + + _.withConstantFeature(CardUriHost, safeCardUriHost) + } else { + identity + } + } + + private[this] def getRelationshipFeatures( + authorIdOpt: Option[Long], + viewerIdOpt: Option[Long], + relationshipFeatures: RelationshipFeatures + ): FeatureMapBuilder => FeatureMapBuilder = { + authorIdOpt match { + case Some(authorId) => relationshipFeatures.forAuthorId(authorId, viewerIdOpt) + case _ => relationshipFeatures.forNoAuthor() + } + } + + private[this] def getTweetFeatures( + tweetOpt: Option[tweetypiethrift.Tweet], + tweetFeatures: TweetFeatures + ): FeatureMapBuilder => FeatureMapBuilder = { + tweetOpt match { + case Some(tweet) => tweetFeatures.forTweet(tweet) + case _ => identity + } + } + + private[this] def getCommunityFeatures( + tweetOpt: Option[tweetypiethrift.Tweet], + viewerContext: ViewerContext, + communityTweetFeatures: CommunityTweetFeatures + ): FeatureMapBuilder => FeatureMapBuilder = { + tweetOpt match { + case Some(tweet) => communityTweetFeatures.forTweet(tweet, viewerContext) + case _ => identity + } + } + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource = UserSource.empty, + userRelationshipSource: UserRelationshipSource = UserRelationshipSource.empty, + communitiesSource: CommunitiesSource = CommunitiesSource.empty, + enableVfParityTest: Gate[Unit] = Gate.False, + enableVfFeatureHydration: Gate[Unit] = Gate.False, + decider: Decider + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = libraryStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + val cardUriStats = libraryStatsReceiver.scope("card_uri") + val visibilityDeciderGates = VisibilityDeciderGates(decider) + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val tweetFeatures = new TweetFeatures(NilTweetLabelMaps, libraryStatsReceiver) + val communityTweetFeatures = new CommunityTweetFeaturesV2( + communitiesSource = communitiesSource, + ) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val parityTest = new CardVisibilityLibraryParityTest(libraryStatsReceiver) + + { r: CardVisibilityRequest => + val elapsed = Stopwatch.start() + var runStitchStartMs = 0L + + val viewerId: Option[UserId] = r.viewerContext.userId + + val featureMap = + visibilityLibrary + .featureMapBuilder( + Seq( + viewerFeatures.forViewerId(viewerId), + getAuthorFeatures(r.authorId, authorFeatures), + getCardUriFeatures( + cardUri = r.cardUri, + enableCardVisibilityLibraryCardUriParsing = + visibilityDeciderGates.enableCardVisibilityLibraryCardUriParsing(), + trackCardUriHost = { safeCardUriHost: Option[String] => + if (safeCardUriHost.isEmpty) { + cardUriStats.counter("empty").incr() + } + } + ), + getCommunityFeatures(r.tweetOpt, r.viewerContext, communityTweetFeatures), + getRelationshipFeatures(r.authorId, r.viewerContext.userId, relationshipFeatures), + getTweetFeatures(r.tweetOpt, tweetFeatures), + _.withConstantFeature(CardIsPoll, r.isPollCardType) + ) + ) + + val response = visibilityLibrary + .runRuleEngine( + 
CardId(r.cardUri), + featureMap, + r.viewerContext, + r.safetyLevel + ) + .onSuccess(_ => { + val overallStatMs = elapsed().inMilliseconds + vfLatencyOverallStat.add(overallStatMs) + val runStitchEndMs = elapsed().inMilliseconds + vfLatencyStitchRunStat.add(runStitchEndMs - runStitchStartMs) + }) + + runStitchStartMs = elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + lazy val hydratedFeatureResponse: Stitch[VisibilityResult] = + FeatureMap.resolve(featureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary.runRuleEngine( + CardId(r.cardUri), + resolvedFeatureMap, + r.viewerContext, + r.safetyLevel + ) + } + + val isVfParityTestEnabled = enableVfParityTest() + val isVfFeatureHydrationEnabled = enableVfFeatureHydration() + + if (!isVfParityTestEnabled && !isVfFeatureHydrationEnabled) { + response + } else if (isVfParityTestEnabled && !isVfFeatureHydrationEnabled) { + response.applyEffect { resp => + Stitch.async(parityTest.runParityTest(hydratedFeatureResponse, resp)) + } + } else { + hydratedFeatureResponse + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibraryParityTest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibraryParityTest.scala new file mode 100644 index 000000000..4dc3f6baf --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityLibraryParityTest.scala @@ -0,0 +1,35 @@ +package com.twitter.visibility.interfaces.cards + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.visibility.builder.VisibilityResult + +class CardVisibilityLibraryParityTest(statsReceiver: StatsReceiver) { + private val parityTestScope = statsReceiver.scope("card_visibility_library_parity") + private val requests = parityTestScope.counter("requests") + private val equal = parityTestScope.counter("equal") + private val incorrect = parityTestScope.counter("incorrect") + private val failures = parityTestScope.counter("failures") + + def runParityTest( + preHydratedFeatureVisibilityResult: Stitch[VisibilityResult], + resp: VisibilityResult + ): Stitch[Unit] = { + requests.incr() + + preHydratedFeatureVisibilityResult + .flatMap { parityResponse => + if (parityResponse.verdict == resp.verdict) { + equal.incr() + } else { + incorrect.incr() + } + + Stitch.Done + }.rescue { + case t: Throwable => + failures.incr() + Stitch.Done + }.unit + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityRequest.scala new file mode 100644 index 000000000..c3cef91d2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/cards/CardVisibilityRequest.scala @@ -0,0 +1,13 @@ +package com.twitter.visibility.interfaces.cards + +import com.twitter.tweetypie.{thriftscala => tweetypiethrift} +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class CardVisibilityRequest( + cardUri: String, + authorId: Option[Long], + tweetOpt: Option[tweetypiethrift.Tweet], + isPollCardType: Boolean, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/BUILD.bazel 
b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/BUILD.bazel new file mode 100644 index 000000000..e8d9d1b25 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/BUILD.bazel @@ -0,0 +1,15 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + strict_deps = True, + dependencies = [ + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/spam/rtf:safety-label-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/catalog", + "strato/src/main/scala/com/twitter/strato/client", + "strato/src/main/scala/com/twitter/strato/data", + "strato/src/main/scala/com/twitter/strato/thrift", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BUILD new file mode 100644 index 000000000..c0f745c39 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BUILD @@ -0,0 +1,14 @@ +scala_library( + sources = [ + "*.scala", + ], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/search/common:constants-scala", + "stitch/stitch-core", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BlenderVFRequestContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BlenderVFRequestContext.scala new file mode 100644 index 000000000..05e98c17d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/blender/BlenderVFRequestContext.scala @@ -0,0 +1,19 @@ +package com.twitter.visibility.interfaces.common.blender + +import com.twitter.search.blender.services.strato.UserSearchSafetySettings +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource + +case class BlenderVFRequestContext( + resultsPageNumber: Int, + candidateCount: Int, + querySourceOption: Option[ThriftQuerySource], + userSearchSafetySettings: UserSearchSafetySettings, + queryHasUser: Boolean = false) { + + def this( + resultsPageNumber: Int, + candidateCount: Int, + querySourceOption: Option[ThriftQuerySource], + userSearchSafetySettings: UserSearchSafetySettings + ) = this(resultsPageNumber, candidateCount, querySourceOption, userSearchSafetySettings, false) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/BUILD new file mode 100644 index 000000000..c0f745c39 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/BUILD @@ -0,0 +1,14 @@ +scala_library( + sources = [ + "*.scala", + ], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/search/common:constants-scala", + "stitch/stitch-core", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/SearchVFRequestContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/SearchVFRequestContext.scala new file mode 100644 index 
000000000..ef06b0b3b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/search/SearchVFRequestContext.scala @@ -0,0 +1,19 @@ +package com.twitter.visibility.interfaces.common.search + +import com.twitter.search.blender.services.strato.UserSearchSafetySettings +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource + +case class SearchVFRequestContext( + resultsPageNumber: Int, + candidateCount: Int, + querySourceOption: Option[ThriftQuerySource], + userSearchSafetySettings: UserSearchSafetySettings, + queryHasUser: Boolean = false) { + + def this( + resultsPageNumber: Int, + candidateCount: Int, + querySourceOption: Option[ThriftQuerySource], + userSearchSafetySettings: UserSearchSafetySettings + ) = this(resultsPageNumber, candidateCount, querySourceOption, userSearchSafetySettings, false) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/BUILD new file mode 100644 index 000000000..64400fae2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/BUILD @@ -0,0 +1,17 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/spam/rtf:safety-label-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/catalog", + "strato/src/main/scala/com/twitter/strato/client", + "strato/src/main/scala/com/twitter/strato/data", + "strato/src/main/scala/com/twitter/strato/thrift", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelFetcher.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelFetcher.scala new file mode 100644 index 000000000..324620892 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelFetcher.scala @@ -0,0 +1,18 @@ +package com.twitter.visibility.interfaces.common.tweets + +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.strato.client.Fetcher +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.strato.thrift.ScroogeConvImplicits._ +import com.twitter.util.Memoize + +object StratoSafetyLabelFetcher { + def apply(client: StratoClient): SafetyLabelFetcherType = { + val getFetcher: SafetyLabelType => Fetcher[Long, Unit, SafetyLabel] = + Memoize((safetyLabelType: SafetyLabelType) => + client.fetcher[Long, SafetyLabel](s"visibility/${safetyLabelType.name}.Tweet")) + + (tweetId, safetyLabelType) => getFetcher(safetyLabelType).fetch(tweetId).map(_.v) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelMapFetcher.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelMapFetcher.scala new file mode 100644 index 000000000..a76ad92a9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/StratoSafetyLabelMapFetcher.scala @@ -0,0 +1,17 @@ +package com.twitter.visibility.interfaces.common.tweets + +import com.twitter.spam.rtf.thriftscala.SafetyLabelMap +import 
com.twitter.strato.client.Fetcher +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.strato.thrift.ScroogeConvImplicits._ + +object StratoSafetyLabelMapFetcher { + val column = "visibility/baseTweetSafetyLabelMap" + + def apply(client: StratoClient): SafetyLabelMapFetcherType = { + val fetcher: Fetcher[Long, Unit, SafetyLabelMap] = + client.fetcher[Long, SafetyLabelMap](column) + + tweetId => fetcher.fetch(tweetId).map(_.v) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/package.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/package.scala new file mode 100644 index 000000000..82a8033d6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/common/tweets/package.scala @@ -0,0 +1,13 @@ +package com.twitter.visibility.interfaces.common + +import com.twitter.search.blender.services.strato.UserSearchSafetySettings +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelMap +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.stitch.Stitch + +package object tweets { + type SafetyLabelFetcherType = (Long, SafetyLabelType) => Stitch[Option[SafetyLabel]] + type SafetyLabelMapFetcherType = Long => Stitch[Option[SafetyLabelMap]] + type UserSearchSafetySettingsFetcherType = Long => Stitch[UserSearchSafetySettings] +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/AdAvoidanceLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/AdAvoidanceLibrary.scala new file mode 100644 index 000000000..067fa833f --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/AdAvoidanceLibrary.scala @@ -0,0 +1,158 @@ +package com.twitter.visibility.interfaces.conversations + +import com.google.common.annotations.VisibleForTesting +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.spam.rtf.thriftscala.SafetyLevel +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.GetTweetFieldsResult +import com.twitter.tweetypie.thriftscala.TweetFieldsResultFound +import com.twitter.tweetypie.thriftscala.TweetFieldsResultState +import com.twitter.util.Stopwatch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.common.filtered_reason.FilteredReasonHelper +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Interstitial +import com.twitter.visibility.rules.Tombstone + +case class AdAvoidanceRequest( + conversationId: Long, + focalTweetId: Long, + tweets: Seq[(GetTweetFieldsResult, Option[SafetyLevel])], + authorMap: Map[ + Long, + User + ], + moderatedTweetIds: Seq[Long], + viewerContext: ViewerContext, + useRichText: Boolean = true) + +case class AdAvoidanceResponse(dropAd: Map[Long, Boolean]) + +object AdAvoidanceLibrary { + type Type = + AdAvoidanceRequest => Stitch[AdAvoidanceResponse] + + private def shouldAvoid( + result: TweetFieldsResultState, + tombstoneOpt: Option[VfTombstone], + statsReceiver: StatsReceiver + ): Boolean = { + shouldAvoid(result, statsReceiver) || shouldAvoid(tombstoneOpt, statsReceiver) + } + + private def shouldAvoid( + result: TweetFieldsResultState, + statsReceiver: StatsReceiver + ): Boolean = { + result match { + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, 
Some(filteredReason))) + if FilteredReasonHelper.isAvoid(filteredReason) => + statsReceiver.counter("avoid").incr() + true + case _ => false + } + } + + private def shouldAvoid( + tombstoneOpt: Option[VfTombstone], + statsReceiver: StatsReceiver, + ): Boolean = { + tombstoneOpt + .map(_.action).collect { + case Tombstone(epitaph, _) => + statsReceiver.scope("tombstone").counter(epitaph.name).incr() + true + case interstitial: Interstitial => + statsReceiver.scope("interstitial").counter(interstitial.reason.name).incr() + true + case _ => false + }.getOrElse(false) + } + + private def runTombstoneVisLib( + request: AdAvoidanceRequest, + tombstoneVisibilityLibrary: TombstoneVisibilityLibrary, + ): Stitch[TombstoneVisibilityResponse] = { + val tombstoneRequest = TombstoneVisibilityRequest( + conversationId = request.conversationId, + focalTweetId = request.focalTweetId, + tweets = request.tweets, + authorMap = request.authorMap, + moderatedTweetIds = request.moderatedTweetIds, + viewerContext = request.viewerContext, + useRichText = request.useRichText + ) + + tombstoneVisibilityLibrary(tombstoneRequest) + } + + def buildTweetAdAvoidanceMap(tweets: Seq[GetTweetFieldsResult]): Map[Long, Boolean] = tweets + .map(tweet => { + val shouldAvoid = tweet.tweetResult match { + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) => + FilteredReasonHelper.isAvoid(filteredReason) + case _ => false + } + + tweet.tweetId -> shouldAvoid + }).toMap + + def apply(visibilityLibrary: VisibilityLibrary, decider: Decider): Type = { + val tvl = + TombstoneVisibilityLibrary(visibilityLibrary, visibilityLibrary.statsReceiver, decider) + buildLibrary(tvl, visibilityLibrary.statsReceiver) + } + + @VisibleForTesting + def buildLibrary( + tvl: TombstoneVisibilityLibrary, + libraryStatsReceiver: StatsReceiver + ): AdAvoidanceLibrary.Type = { + + val statsReceiver = libraryStatsReceiver.scope("AdAvoidanceLibrary") + val reasonsStatsReceiver = statsReceiver.scope("reasons") + val latencyStatsReceiver = statsReceiver.scope("latency") + val vfLatencyOverallStat = latencyStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = latencyStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = latencyStatsReceiver.stat("vf_latency_stitch_run") + + request: AdAvoidanceRequest => { + val elapsed = Stopwatch.start() + + var runStitchStartMs = 0L + + val tombstoneResponse: Stitch[TombstoneVisibilityResponse] = + runTombstoneVisLib(request, tvl) + + val response = tombstoneResponse + .map({ response: TombstoneVisibilityResponse => + statsReceiver.counter("requests").incr(request.tweets.size) + + val dropResults: Seq[(Long, Boolean)] = request.tweets.map(tweetAndSafetyLevel => { + val tweet = tweetAndSafetyLevel._1 + tweet.tweetId -> + shouldAvoid( + tweet.tweetResult, + response.tweetVerdicts.get(tweet.tweetId), + reasonsStatsReceiver) + }) + + AdAvoidanceResponse(dropAd = dropResults.toMap) + }) + .onSuccess(_ => { + val overallStatMs = elapsed().inMilliseconds + vfLatencyOverallStat.add(overallStatMs) + val runStitchEndMs = elapsed().inMilliseconds + vfLatencyStitchRunStat.add(runStitchEndMs - runStitchStartMs) + }) + + runStitchStartMs = elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + response + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/BUILD 
new file mode 100644 index 000000000..eae843def --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/BUILD @@ -0,0 +1,46 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "3rdparty/jvm/com/ibm/icu:icu4j", + "decider/src/main/scala", + "servo/decider/src/main/scala", + "servo/repo", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/gizmoduck:thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/timelines/render:thrift-scala", + "src/thrift/com/twitter/tweetypie:service-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "translation/src/main/scala", + "twitter-config/yaml", + "twitter-context/src/main/scala", + "urt/richtext/src/main/scala/com/twitter/urt/richtext", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/filtered_reason", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + "visibility/results/src/main/scala/com/twitter/visibility/results/richtext", + "visibility/results/src/main/scala/com/twitter/visibility/results/translation", + "visibility/results/src/main/scala/com/twitter/visibility/results/urt", + ], + exports = [ + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityLibrary.scala new file mode 100644 index 000000000..cefe6d762 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityLibrary.scala @@ -0,0 +1,260 @@ +package com.twitter.visibility.interfaces.conversations + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.servo.repository.KeyValueResult +import com.twitter.servo.util.Gate +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.spam.rtf.thriftscala.SafetyLabelValue +import com.twitter.stitch.Stitch +import com.twitter.util.Future +import com.twitter.util.Return +import com.twitter.util.Stopwatch +import com.twitter.util.Try +import 
com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.tweets.TweetIdFeatures +import com.twitter.visibility.builder.FeatureMapBuilder +import com.twitter.visibility.builder.VerdictLogger +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.FosnrPefetchedLabelsRelationshipFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.AuthorUserLabels +import com.twitter.visibility.features.ConversationRootAuthorIsVerified +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.features.HasInnerCircleOfFriendsRelationship +import com.twitter.visibility.features.TweetConversationId +import com.twitter.visibility.features.TweetParentId +import com.twitter.visibility.logging.thriftscala.VFLibType +import com.twitter.visibility.models.ContentId.TweetId +import com.twitter.visibility.models.SafetyLevel.TimelineConversationsDownranking +import com.twitter.visibility.models.SafetyLevel.TimelineConversationsDownrankingMinimal +import com.twitter.visibility.models.SafetyLevel.toThrift +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.UnitOfDiversion + +object TimelineConversationsVisibilityLibrary { + type Type = + TimelineConversationsVisibilityRequest => Stitch[TimelineConversationsVisibilityResponse] + + def apply( + visibilityLibrary: VisibilityLibrary, + batchSafetyLabelRepository: BatchSafetyLabelRepository, + decider: Decider, + userRelationshipSource: UserRelationshipSource = UserRelationshipSource.empty, + userSource: UserSource = UserSource.empty + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val tweetIdFeatures = new TweetIdFeatures( + statsReceiver = libraryStatsReceiver, + enableStitchProfiling = Gate.False + ) + val tweetIdFeaturesMinimal = new TweetIdFeatures( + statsReceiver = libraryStatsReceiver, + enableStitchProfiling = Gate.False + ) + val vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = libraryStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + + val visibilityDeciderGates = VisibilityDeciderGates(decider) + val verdictLogger = + createVerdictLogger( + visibilityDeciderGates.enableVerdictLoggerTCVL, + decider, + libraryStatsReceiver) + + request: TimelineConversationsVisibilityRequest => + val elapsed = Stopwatch.start() + var runStitchStartMs = 0L + + val future = request.prefetchedSafetyLabels match { + case Some(labels) => Future.value(labels) + case _ => + batchSafetyLabelRepository((request.conversationId, request.tweetIds)) + } + + val fosnrPefetchedLabelsRelationshipFeatures = + new FosnrPefetchedLabelsRelationshipFeatures( + userRelationshipSource = userRelationshipSource, + statsReceiver = libraryStatsReceiver) + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + + Stitch.callFuture(future).flatMap { + kvr: KeyValueResult[Long, scala.collection.Map[SafetyLabelType, SafetyLabel]] => + val featureMapProvider: (ContentId, SafetyLevel) => FeatureMap = { + case (TweetId(tweetId), safetyLevel) => + val constantTweetSafetyLabels: 
Seq[TweetSafetyLabel] = + kvr.found.getOrElse(tweetId, Map.empty).toSeq.map { + case (safetyLabelType, safetyLabel) => + TweetSafetyLabel.fromThrift(SafetyLabelValue(safetyLabelType, safetyLabel)) + } + + val replyAuthor = request.tweetAuthors.flatMap { + _(tweetId) match { + case Return(Some(userId)) => Some(userId) + case _ => None + } + } + + val fosnrPefetchedLabelsRelationshipFeatureConf = replyAuthor match { + case Some(authorId) if visibilityLibrary.isReleaseCandidateEnabled => + fosnrPefetchedLabelsRelationshipFeatures + .forTweetWithSafetyLabelsAndAuthorId( + safetyLabels = constantTweetSafetyLabels, + authorId = authorId, + viewerId = request.viewerContext.userId) + case _ => fosnrPefetchedLabelsRelationshipFeatures.forNonFosnr() + } + + val authorFeatureConf = replyAuthor match { + case Some(authorId) if visibilityLibrary.isReleaseCandidateEnabled => + authorFeatures.forAuthorId(authorId) + case _ => authorFeatures.forNoAuthor() + } + + val baseBuilderArguments = (safetyLevel match { + case TimelineConversationsDownranking => + Seq(tweetIdFeatures.forTweetId(tweetId, constantTweetSafetyLabels)) + case TimelineConversationsDownrankingMinimal => + Seq(tweetIdFeaturesMinimal.forTweetId(tweetId, constantTweetSafetyLabels)) + case _ => Nil + }) :+ fosnrPefetchedLabelsRelationshipFeatureConf :+ authorFeatureConf + + val tweetAuthorUserLabels: Option[Seq[Label]] = + request.prefetchedTweetAuthorUserLabels.flatMap { + _.apply(tweetId) match { + case Return(Some(labelMap)) => + Some(labelMap.values.toSeq) + case _ => + None + } + } + + val hasInnerCircleOfFriendsRelationship: Boolean = + request.innerCircleOfFriendsRelationships match { + case Some(keyValueResult) => + keyValueResult(tweetId) match { + case Return(Some(true)) => true + case _ => false + } + case None => false + } + + val builderArguments: Seq[FeatureMapBuilder => FeatureMapBuilder] = + tweetAuthorUserLabels match { + case Some(labels) => + baseBuilderArguments :+ { (fmb: FeatureMapBuilder) => + fmb.withConstantFeature(AuthorUserLabels, labels) + } + + case None => + baseBuilderArguments :+ { (fmb: FeatureMapBuilder) => + fmb.withConstantFeature(AuthorUserLabels, Seq.empty) + } + } + + val tweetParentIdOpt: Option[Long] = + request.tweetParentIdMap.flatMap(tweetParentIdMap => tweetParentIdMap(tweetId)) + + visibilityLibrary.featureMapBuilder(builderArguments :+ { (fmb: FeatureMapBuilder) => + fmb.withConstantFeature( + HasInnerCircleOfFriendsRelationship, + hasInnerCircleOfFriendsRelationship) + .withConstantFeature(TweetConversationId, request.conversationId) + .withConstantFeature(TweetParentId, tweetParentIdOpt) + .withConstantFeature( + ConversationRootAuthorIsVerified, + request.rootAuthorIsVerified) + }) + case _ => + visibilityLibrary.featureMapBuilder(Nil) + } + val safetyLevel = + if (request.minimalSectioningOnly) TimelineConversationsDownrankingMinimal + else TimelineConversationsDownranking + + val evaluationContextBuilder = visibilityLibrary + .evaluationContextBuilder(request.viewerContext) + .withUnitOfDiversion(UnitOfDiversion.ConversationId(request.conversationId)) + + visibilityLibrary + .runRuleEngineBatch( + request.tweetIds.map(TweetId), + featureMapProvider, + evaluationContextBuilder, + safetyLevel + ) + .map { results: Seq[Try[VisibilityResult]] => + val (succeededRequests, _) = results.partition(_.exists(_.finished)) + val visibilityResultMap = succeededRequests.flatMap { + case Return(result) => + scribeVisibilityVerdict( + result,
visibilityDeciderGates.enableVerdictScribingTCVL, + verdictLogger, + request.viewerContext.userId, + safetyLevel) + result.contentId match { + case TweetId(id) => Some((id, result)) + case _ => None + } + case _ => None + }.toMap + val failedTweetIds = request.tweetIds diff visibilityResultMap.keys.toSeq + val response = TimelineConversationsVisibilityResponse( + visibilityResults = visibilityResultMap, + failedTweetIds = failedTweetIds + ) + + runStitchStartMs = elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + response + } + .onSuccess(_ => { + val overallStatMs = elapsed().inMilliseconds + vfLatencyOverallStat.add(overallStatMs) + val runStitchEndMs = elapsed().inMilliseconds + vfLatencyStitchRunStat.add(runStitchEndMs - runStitchStartMs) + }) + } + } + + def scribeVisibilityVerdict( + visibilityResult: VisibilityResult, + enableVerdictScribing: Gate[Unit], + verdictLogger: VerdictLogger, + viewerId: Option[Long], + safetyLevel: SafetyLevel + ): Unit = if (enableVerdictScribing()) { + verdictLogger.scribeVerdict( + visibilityResult = visibilityResult, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.TimelineConversationsVisibilityLibrary) + } + + def createVerdictLogger( + enableVerdictLogger: Gate[Unit], + decider: Decider, + statsReceiver: StatsReceiver + ): VerdictLogger = { + if (enableVerdictLogger()) { + VerdictLogger(statsReceiver, decider) + } else { + VerdictLogger.Empty + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityRequest.scala new file mode 100644 index 000000000..217296f8b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityRequest.scala @@ -0,0 +1,20 @@ +package com.twitter.visibility.interfaces.conversations + +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.gizmoduck.thriftscala.LabelValue +import com.twitter.servo.repository.KeyValueResult +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import com.twitter.visibility.models.ViewerContext + +case class TimelineConversationsVisibilityRequest( + conversationId: Long, + tweetIds: Seq[Long], + viewerContext: ViewerContext, + minimalSectioningOnly: Boolean = false, + prefetchedSafetyLabels: Option[KeyValueResult[Long, Map[SafetyLabelType, SafetyLabel]]] = None, + prefetchedTweetAuthorUserLabels: Option[KeyValueResult[Long, Map[LabelValue, Label]]] = None, + innerCircleOfFriendsRelationships: Option[KeyValueResult[Long, Boolean]] = None, + tweetParentIdMap: Option[Map[Long, Option[Long]]] = None, + rootAuthorIsVerified: Boolean = false, + tweetAuthors: Option[KeyValueResult[Long, Long]] = None) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityResponse.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityResponse.scala new file mode 100644 index 000000000..f086c792c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TimelineConversationsVisibilityResponse.scala @@ -0,0 +1,7 @@ +package com.twitter.visibility.interfaces.conversations + +import 
com.twitter.visibility.builder.VisibilityResult + +case class TimelineConversationsVisibilityResponse( + visibilityResults: Map[Long, VisibilityResult], + failedTweetIds: Seq[Long]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/Tombstone.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/Tombstone.scala new file mode 100644 index 000000000..1d3274ed2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/Tombstone.scala @@ -0,0 +1,35 @@ +package com.twitter.visibility.interfaces.conversations + +import com.twitter.timelines.render.thriftscala.TombstoneDisplayType +import com.twitter.timelines.render.thriftscala.TombstoneInfo +import com.twitter.visibility.rules._ + +case class VfTombstone( + tombstoneId: Long, + includeTweet: Boolean, + action: Action, + tombstoneInfo: Option[TombstoneInfo] = None, + tombstoneDisplayType: TombstoneDisplayType = TombstoneDisplayType.Inline, + truncateDescendantsWhenFocal: Boolean = false) { + + val isTruncatable: Boolean = action match { + case Interstitial(Reason.ViewerBlocksAuthor, _, _) => true + case Interstitial(Reason.ViewerHardMutedAuthor, _, _) => true + case Interstitial(Reason.MutedKeyword, _, _) => true + case Tombstone(Epitaph.NotFound, _) => true + case Tombstone(Epitaph.Unavailable, _) => true + case Tombstone(Epitaph.Suspended, _) => true + case Tombstone(Epitaph.Protected, _) => true + case Tombstone(Epitaph.Deactivated, _) => true + case Tombstone(Epitaph.BlockedBy, _) => true + case Tombstone(Epitaph.Moderated, _) => true + case Tombstone(Epitaph.Deleted, _) => true + case Tombstone(Epitaph.Underage, _) => true + case Tombstone(Epitaph.NoStatedAge, _) => true + case Tombstone(Epitaph.LoggedOutAge, _) => true + case Tombstone(Epitaph.SuperFollowsContent, _) => true + case Tombstone(Epitaph.CommunityTweetHidden, _) => true + case _: LocalizedTombstone => true + case _ => false + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TombstoneVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TombstoneVisibilityLibrary.scala new file mode 100644 index 000000000..3f228670d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/TombstoneVisibilityLibrary.scala @@ -0,0 +1,633 @@ +package com.twitter.visibility.interfaces.conversations + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.spam.rtf.thriftscala.FilteredReason +import com.twitter.spam.rtf.thriftscala.FilteredReason.UnspecifiedReason +import com.twitter.spam.rtf.thriftscala.SafetyLevel +import com.twitter.spam.rtf.thriftscala.SafetyResult +import com.twitter.stitch.Stitch +import com.twitter.timelines.render.thriftscala.RichText +import com.twitter.timelines.render.thriftscala.TombstoneDisplayType +import com.twitter.timelines.render.thriftscala.TombstoneInfo +import com.twitter.tweetypie.thriftscala.GetTweetFieldsResult +import com.twitter.tweetypie.thriftscala.TweetFieldsResultFailed +import com.twitter.tweetypie.thriftscala.TweetFieldsResultFiltered +import com.twitter.tweetypie.thriftscala.TweetFieldsResultFound +import com.twitter.tweetypie.thriftscala.TweetFieldsResultNotFound +import com.twitter.tweetypie.thriftscala.TweetFieldsResultState +import com.twitter.visibility.VisibilityLibrary +import 
com.twitter.visibility.builder.tweets.ModerationFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.common.actions.InterstitialReason +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.common.actions.TombstoneReason +import com.twitter.visibility.common.actions.converter.scala.InterstitialReasonConverter +import com.twitter.visibility.common.actions.converter.scala.LocalizedMessageConverter +import com.twitter.visibility.common.actions.converter.scala.TombstoneReasonConverter +import com.twitter.visibility.common.filtered_reason.FilteredReasonHelper +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.FocalTweetId +import com.twitter.visibility.features.TweetId +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel.Tombstoning +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.results.richtext.EpitaphToRichText +import com.twitter.visibility.results.richtext.LocalizedMessageToRichText +import com.twitter.visibility.results.urt.ReasonToUrtParser +import com.twitter.visibility.results.urt.SafetyResultToUrtParser +import com.twitter.visibility.rules._ +import com.twitter.visibility.{thriftscala => t} + +case class TombstoneVisibilityRequest( + conversationId: Long, + focalTweetId: Long, + tweets: Seq[(GetTweetFieldsResult, Option[SafetyLevel])], + authorMap: Map[ + Long, + User + ], + moderatedTweetIds: Seq[Long], + viewerContext: ViewerContext, + useRichText: Boolean = true) + +case class TombstoneVisibilityResponse(tweetVerdicts: Map[Long, VfTombstone]) + +case class TombstoneVisibilityLibrary( + visibilityLibrary: VisibilityLibrary, + statsReceiver: StatsReceiver, + decider: Decider) { + + private case class TombstoneType( + tweetId: Long, + tombstoneId: Long, + action: Action) { + + lazy val isInnerTombstone: Boolean = tweetId != tombstoneId + + lazy val tombstoneDisplayType: TombstoneDisplayType = action match { + case _: InterstitialLimitedEngagements | _: EmergencyDynamicInterstitial => + TombstoneDisplayType.NonCompliant + case _ => TombstoneDisplayType.Inline + } + } + + val En: String = "en" + val View: String = "View" + val relationshipFeatures = + new RelationshipFeatures( + statsReceiver) + val visibilityDeciderGates = VisibilityDeciderGates(decider) + + + def toAction( + filteredReason: FilteredReason, + actionStatsReceiver: StatsReceiver + ): Option[Action] = { + + val enableLocalizedInterstitials = + visibilityDeciderGates.enableConvosLocalizedInterstitial() + val enableLegacyInterstitials = + visibilityDeciderGates.enableConvosLegacyInterstitial() + + val tombstoneStatsReceiver = actionStatsReceiver.scope("tombstone") + val interstitialLocalStatsReceiver = + actionStatsReceiver.scope("interstitial").scope("localized") + val interstitialLegacyStatsReceiver = + actionStatsReceiver.scope("interstitial").scope("legacy") + + filteredReason match { + case _ if FilteredReasonHelper.isTombstone(filteredReason) => + createLocalizedTombstone(filteredReason, tombstoneStatsReceiver) match { + case tombstoneOpt @ Some(LocalizedTombstone(_, _)) => tombstoneOpt + case _ => + createTombstone(Epitaph.Unavailable, 
tombstoneStatsReceiver, Some("emptyTombstone")) + } + + case _ + if enableLocalizedInterstitials && + FilteredReasonHelper.isLocalizedSuppressedReasonInterstitial(filteredReason) => + FilteredReasonHelper.getLocalizedSuppressedReasonInterstitial(filteredReason) match { + case Some(t.Interstitial(reasonOpt, Some(message))) => + InterstitialReasonConverter.fromThrift(reasonOpt).map { interstitialReason => + interstitialLocalStatsReceiver.counter("interstitial").incr() + Interstitial( + Reason.fromInterstitialReason(interstitialReason), + Some(LocalizedMessageConverter.fromThrift(message))) + } + + case _ => None + } + + case _ if FilteredReasonHelper.containNsfwMedia(filteredReason) => + None + + case _ if FilteredReasonHelper.possiblyUndesirable(filteredReason) => + None + + case _ if FilteredReasonHelper.reportedTweet(filteredReason) => + filteredReason match { + case FilteredReason.ReportedTweet(true) => + interstitialLegacyStatsReceiver.counter("fr_reported").incr() + Some(Interstitial(Reason.ViewerReportedAuthor)) + + case FilteredReason.SafetyResult(safetyResult: SafetyResult) + if enableLegacyInterstitials => + val safetyResultReported = InterstitialReasonConverter + .fromAction(safetyResult.action).collect { + case InterstitialReason.ViewerReportedTweet => true + case InterstitialReason.ViewerReportedAuthor => true + }.getOrElse(false) + + if (safetyResultReported) { + interstitialLegacyStatsReceiver.counter("reported_author").incr() + Some(Interstitial(Reason.ViewerReportedAuthor)) + } else None + + case _ => None + } + + case _ if FilteredReasonHelper.tweetMatchesViewerMutedKeyword(filteredReason) => + filteredReason match { + case FilteredReason.TweetMatchesViewerMutedKeyword(_) => + interstitialLegacyStatsReceiver.counter("fr_muted_keyword").incr() + Some(Interstitial(Reason.MutedKeyword)) + + case FilteredReason.SafetyResult(safetyResult: SafetyResult) + if enableLegacyInterstitials => + val safetyResultMutedKeyword = InterstitialReasonConverter + .fromAction(safetyResult.action).collect { + case _: InterstitialReason.MatchesMutedKeyword => true + }.getOrElse(false) + + if (safetyResultMutedKeyword) { + interstitialLegacyStatsReceiver.counter("muted_keyword").incr() + Some(Interstitial(Reason.MutedKeyword)) + } else None + + case _ => None + } + + case _ => + None + } + } + + def toAction( + tfrs: TweetFieldsResultState, + actionStatsReceiver: StatsReceiver + ): Option[Action] = { + + val enableLocalizedInterstitials = visibilityDeciderGates.enableConvosLocalizedInterstitial() + val enableLegacyInterstitials = visibilityDeciderGates.enableConvosLegacyInterstitial() + + val tombstoneStatsReceiver = actionStatsReceiver.scope("tombstone") + val interstitialLocalStatsReceiver = + actionStatsReceiver.scope("interstitial").scope("localized") + val interstitialLegacyStatsReceiver = + actionStatsReceiver.scope("interstitial").scope("legacy") + + tfrs match { + + case TweetFieldsResultState.NotFound(TweetFieldsResultNotFound(_, _, Some(filteredReason))) + if FilteredReasonHelper.isTombstone(filteredReason) => + createLocalizedTombstone(filteredReason, tombstoneStatsReceiver) + + case TweetFieldsResultState.NotFound(tfr: TweetFieldsResultNotFound) if tfr.deleted => + createTombstone(Epitaph.Deleted, tombstoneStatsReceiver) + + case TweetFieldsResultState.NotFound(_: TweetFieldsResultNotFound) => + createTombstone(Epitaph.NotFound, tombstoneStatsReceiver) + + case TweetFieldsResultState.Failed(TweetFieldsResultFailed(_, _, _)) => + createTombstone(Epitaph.Unavailable, 
tombstoneStatsReceiver, Some("failed")) + + case TweetFieldsResultState.Filtered(TweetFieldsResultFiltered(UnspecifiedReason(true))) => + createTombstone(Epitaph.Unavailable, tombstoneStatsReceiver, Some("filtered")) + + case TweetFieldsResultState.Filtered(TweetFieldsResultFiltered(filteredReason)) => + toAction(filteredReason, actionStatsReceiver) + + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) + if enableLocalizedInterstitials && + FilteredReasonHelper.isSuppressedReasonPublicInterestInterstial(filteredReason) => + interstitialLocalStatsReceiver.counter("ipi").incr() + FilteredReasonHelper + .getSafetyResult(filteredReason) + .flatMap(_.reason) + .flatMap(PublicInterest.SafetyResultReasonToReason.get) match { + case Some(safetyResultReason) => + FilteredReasonHelper + .getSuppressedReasonPublicInterestInterstial(filteredReason) + .map(edi => edi.localizedMessage) + .map(tlm => LocalizedMessageConverter.fromThrift(tlm)) + .map(lm => + InterstitialLimitedEngagements( + safetyResultReason, + Some(LimitedEngagementReason.NonCompliant), + lm)) + case _ => None + } + + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) + if enableLegacyInterstitials && + FilteredReasonHelper.isSuppressedReasonPublicInterestInterstial(filteredReason) => + interstitialLegacyStatsReceiver.counter("ipi").incr() + FilteredReasonHelper + .getSafetyResult(filteredReason) + .flatMap(_.reason) + .flatMap(PublicInterest.SafetyResultReasonToReason.get) + .map(InterstitialLimitedEngagements(_, Some(LimitedEngagementReason.NonCompliant))) + + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) + if enableLocalizedInterstitials && + FilteredReasonHelper.isLocalizedSuppressedReasonEmergencyDynamicInterstitial( + filteredReason) => + interstitialLocalStatsReceiver.counter("edi").incr() + FilteredReasonHelper + .getSuppressedReasonEmergencyDynamicInterstitial(filteredReason) + .map(e => + EmergencyDynamicInterstitial( + e.copy, + e.link, + LocalizedMessageConverter.fromThrift(e.localizedMessage))) + + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) + if enableLegacyInterstitials && + FilteredReasonHelper.isSuppressedReasonEmergencyDynamicInterstitial(filteredReason) => + interstitialLegacyStatsReceiver.counter("edi").incr() + FilteredReasonHelper + .getSuppressedReasonEmergencyDynamicInterstitial(filteredReason) + .map(e => EmergencyDynamicInterstitial(e.copy, e.link)) + + case TweetFieldsResultState.Found(TweetFieldsResultFound(tweet, _, _)) + if tweet.perspective.exists(_.reported) => + interstitialLegacyStatsReceiver.counter("reported").incr() + Some(Interstitial(Reason.ViewerReportedAuthor)) + + case TweetFieldsResultState.Found( + TweetFieldsResultFound(_, _, Some(UnspecifiedReason(true)))) => + None + + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) => + toAction(filteredReason, actionStatsReceiver) + + case _ => + None + } + } + + private[conversations] def shouldTruncateDescendantsWhenFocal(action: Action): Boolean = + action match { + case _: InterstitialLimitedEngagements | _: EmergencyDynamicInterstitial => + true + case Tombstone(Epitaph.Bounced, _) | Tombstone(Epitaph.BounceDeleted, _) => + true + case LocalizedTombstone(TombstoneReason.Bounced, _) | + LocalizedTombstone(TombstoneReason.BounceDeleted, _) => + true + case LimitedEngagements(LimitedEngagementReason.NonCompliant, _) => + true + case _ => false + } + + def 
apply(request: TombstoneVisibilityRequest): Stitch[TombstoneVisibilityResponse] = { + + val moderationFeatures = new ModerationFeatures( + moderationSource = request.moderatedTweetIds.contains, + statsReceiver = statsReceiver + ) + + val userSource = UserSource.fromFunction({ + case (userId, _) => + request.authorMap + .get(userId) + .map(Stitch.value).getOrElse(Stitch.NotFound) + }) + + val authorFeatures = new AuthorFeatures(userSource, statsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, statsReceiver) + + val languageTag = request.viewerContext.requestLanguageCode.getOrElse(En) + val firstRound: Seq[(GetTweetFieldsResult, Option[TombstoneType])] = request.tweets.map { + case (gtfr, safetyLevel) => + val actionStats = statsReceiver + .scope("action") + .scope(safetyLevel.map(_.toString().toLowerCase()).getOrElse("unknown_safety_level")) + toAction(gtfr.tweetResult, actionStats) match { + case Some(action) => + (gtfr, Some(TombstoneType(gtfr.tweetId, gtfr.tweetId, action))) + + case None => + val quotedTweetId: Option[Long] = gtfr.tweetResult match { + case TweetFieldsResultState.Found(TweetFieldsResultFound(tweet, _, _)) => + tweet.quotedTweet.map(_.tweetId) + case _ => None + } + + (quotedTweetId, gtfr.quotedTweetResult) match { + case (Some(quotedTweetId), Some(tfrs)) => + val qtActionStats = actionStats.scope("quoted") + toAction(tfrs, qtActionStats) match { + case None => + (gtfr, None) + + case Some(action) => + (gtfr, Some(TombstoneType(gtfr.tweetId, quotedTweetId, action))) + } + + case _ => + (gtfr, None) + } + } + } + + val (firstRoundActions, secondRoundInput) = firstRound.partition { + case (_, Some(tombstoneType)) => + !tombstoneType.isInnerTombstone + case (_, None) => false + } + + def invokeVisibilityLibrary(tweetId: Long, author: User): Stitch[Action] = { + visibilityLibrary + .runRuleEngine( + ContentId.TweetId(tweetId), + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerContext(request.viewerContext), + moderationFeatures.forTweetId(tweetId), + authorFeatures.forAuthor(author), + relationshipFeatures + .forAuthor(author, request.viewerContext.userId), + _.withConstantFeature(TweetId, tweetId), + _.withConstantFeature(FocalTweetId, request.focalTweetId) + ) + ), + request.viewerContext, + Tombstoning + ).map(_.verdict) + } + + val secondRoundActions: Stitch[Seq[(GetTweetFieldsResult, Option[TombstoneType])]] = + Stitch.traverse(secondRoundInput) { + case (gtfr: GetTweetFieldsResult, firstRoundTombstone: Option[TombstoneType]) => + val secondRoundTombstone: Stitch[Option[TombstoneType]] = gtfr.tweetResult match { + case TweetFieldsResultState.Found(TweetFieldsResultFound(tweet, _, _)) => + val tweetId = tweet.id + + tweet.coreData + .flatMap { coreData => request.authorMap.get(coreData.userId) } match { + case Some(author) => + invokeVisibilityLibrary(tweetId, author).flatMap { + case Allow => + val quotedTweetId = tweet.quotedTweet.map(_.tweetId) + val quotedTweetAuthor = tweet.quotedTweet.flatMap { qt => + request.authorMap.get(qt.userId) + } + + (quotedTweetId, quotedTweetAuthor) match { + case (Some(quotedTweetId), Some(quotedTweetAuthor)) => + invokeVisibilityLibrary(quotedTweetId, quotedTweetAuthor).flatMap { + case Allow => + Stitch.None + + case reason => + Stitch.value(Some(TombstoneType(tweetId, quotedTweetId, reason))) + } + + case _ => + Stitch.None + } + + case reason => + Stitch.value(Some(TombstoneType(tweetId, tweetId, reason))) + } + + case None => + Stitch.None + } + + case _ => + Stitch.None + } +
secondRoundTombstone.map { opt => opt.orElse(firstRoundTombstone) }.map { opt => + (gtfr, opt) + } + } + + secondRoundActions.map { secondRound => + val tombstones: Seq[(Long, VfTombstone)] = (firstRoundActions ++ secondRound).flatMap { + case (gtfr, tombstoneTypeOpt) => { + + val nonCompliantLimitedEngagementsOpt = gtfr.tweetResult match { + case TweetFieldsResultState.Found(TweetFieldsResultFound(_, _, Some(filteredReason))) + if FilteredReasonHelper.isLimitedEngagementsNonCompliant(filteredReason) => + Some(LimitedEngagements(LimitedEngagementReason.NonCompliant)) + case _ => None + } + + (tombstoneTypeOpt, nonCompliantLimitedEngagementsOpt) match { + case (Some(tombstoneType), nonCompliantOpt) => + val tombstoneId = tombstoneType.tombstoneId + val action = tombstoneType.action + val textOpt: Option[RichText] = action match { + + case InterstitialLimitedEngagements(_, _, Some(localizedMessage), _) => + Some(LocalizedMessageToRichText(localizedMessage)) + case ipi: InterstitialLimitedEngagements => + Some( + SafetyResultToUrtParser.fromSafetyResultToRichText( + SafetyResult( + Some(PublicInterest.ReasonToSafetyResultReason(ipi.reason)), + ipi.toActionThrift() + ), + languageTag + ) + ) + + case EmergencyDynamicInterstitial(_, _, Some(localizedMessage), _) => + Some(LocalizedMessageToRichText(localizedMessage)) + case edi: EmergencyDynamicInterstitial => + Some( + SafetyResultToUrtParser.fromSafetyResultToRichText( + SafetyResult( + None, + edi.toActionThrift() + ), + languageTag + ) + ) + + case Tombstone(epitaph, _) => + if (request.useRichText) + Some(EpitaphToRichText(epitaph, languageTag)) + else + Some(EpitaphToRichText(Epitaph.UnavailableWithoutLink, languageTag)) + + case LocalizedTombstone(_, message) => + if (request.useRichText) + Some(LocalizedMessageToRichText(LocalizedMessageConverter.toThrift(message))) + else + Some(EpitaphToRichText(Epitaph.UnavailableWithoutLink, languageTag)) + + case Interstitial(_, Some(localizedMessage), _) => + Some(LocalizedMessageToRichText.apply(localizedMessage)) + + case interstitial: Interstitial => + ReasonToUrtParser.fromReasonToRichText(interstitial.reason, languageTag) + + case _ => + None + } + + val isRoot: Boolean = gtfr.tweetId == request.conversationId + val isOuter: Boolean = tombstoneId == request.conversationId + val revealTextOpt: Option[RichText] = action match { + case _: InterstitialLimitedEngagements | _: EmergencyDynamicInterstitial + if isRoot && isOuter => + None + + case _: Interstitial | _: InterstitialLimitedEngagements | + _: EmergencyDynamicInterstitial => + Some(ReasonToUrtParser.getRichRevealText(languageTag)) + + case _ => + None + } + + val includeTweet = action match { + case _: Interstitial | _: InterstitialLimitedEngagements | + _: EmergencyDynamicInterstitial => + true + case _ => false + } + + val truncateForAction: Boolean = + shouldTruncateDescendantsWhenFocal(action) + val truncateForNonCompliant: Boolean = + nonCompliantOpt + .map(shouldTruncateDescendantsWhenFocal).getOrElse(false) + val truncateDescendants: Boolean = + truncateForAction || truncateForNonCompliant + + val tombstone = textOpt match { + case Some(_) if request.useRichText => + VfTombstone( + tombstoneId = tombstoneId, + includeTweet = includeTweet, + action = action, + tombstoneInfo = Some( + TombstoneInfo( + cta = None, + revealText = None, + richText = textOpt, + richRevealText = revealTextOpt + ) + ), + tombstoneDisplayType = tombstoneType.tombstoneDisplayType, + truncateDescendantsWhenFocal = truncateDescendants + ) + case Some(_) => + VfTombstone( + tombstoneId = tombstoneId, + includeTweet = includeTweet, + action = action, + tombstoneInfo = Some( + TombstoneInfo( + text = textOpt + .map(richText => richText.text).getOrElse(""), + cta = None, + revealText = revealTextOpt.map(_.text), + richText = None, + richRevealText = None + ) + ), + tombstoneDisplayType = tombstoneType.tombstoneDisplayType, + truncateDescendantsWhenFocal = truncateDescendants + ) + + case None => + VfTombstone( + tombstoneId = tombstoneId, + includeTweet = false, + action = action, + tombstoneInfo = Some( + TombstoneInfo( + cta = None, + revealText = None, + richText = Some(EpitaphToRichText(Epitaph.Unavailable, languageTag)), + richRevealText = None + ) + ), + tombstoneDisplayType = tombstoneType.tombstoneDisplayType, + truncateDescendantsWhenFocal = truncateDescendants + ) + } + + Some((gtfr.tweetId, tombstone)) + + case (None, Some(limitedEngagements)) + if shouldTruncateDescendantsWhenFocal(limitedEngagements) => + val tombstone = VfTombstone( + tombstoneId = gtfr.tweetId, + includeTweet = true, + action = limitedEngagements, + tombstoneInfo = None, + tombstoneDisplayType = TombstoneDisplayType.NonCompliant, + truncateDescendantsWhenFocal = true + ) + Some((gtfr.tweetId, tombstone)) + + case _ => + None + } + } + } + + TombstoneVisibilityResponse( + tweetVerdicts = tombstones.toMap + ) + } + } + + private def createLocalizedTombstone( + filteredReason: FilteredReason, + tombstoneStats: StatsReceiver, + ): Option[LocalizedTombstone] = { + + val tombstoneOpt = FilteredReasonHelper.getTombstone(filteredReason) + tombstoneOpt match { + case Some(t.Tombstone(reasonOpt, Some(message))) => + TombstoneReasonConverter.fromThrift(reasonOpt).map { localReason => + tombstoneStats + .scope("localized").counter(localReason.toString().toLowerCase()).incr() + LocalizedTombstone(localReason, LocalizedMessageConverter.fromThrift(message)) + } + + case _ => None + } + } + + private def createTombstone( + epitaph: Epitaph, + tombstoneStats: StatsReceiver, + extraCounterOpt: Option[String] = None + ): Option[Action] = { + tombstoneStats + .scope("legacy") + .counter(epitaph.toString().toLowerCase()) + .incr() + extraCounterOpt.map { extraCounter => + tombstoneStats + .scope("legacy") + .scope(epitaph.toString().toLowerCase()) + .counter(extraCounter) + .incr() + } + Some(Tombstone(epitaph)) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/package.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/package.scala new file mode 100644 index 000000000..4064fc33b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/conversations/package.scala @@ -0,0 +1,11 @@ +package com.twitter.visibility.interfaces + +import com.twitter.servo.repository.KeyValueRepository +import com.twitter.spam.rtf.thriftscala.SafetyLabel +import com.twitter.spam.rtf.thriftscala.SafetyLabelType +import scala.collection.Map + +package object conversations { + type BatchSafetyLabelRepository = + KeyValueRepository[(Long, Seq[Long]), Long, Map[SafetyLabelType, SafetyLabel]] +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/BUILD new file mode 100644 index 000000000..2d2dfacf8 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/BUILD @@ -0,0 +1,23 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [
"src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:events-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + ], + exports = [ + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESRealtimeVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESRealtimeVisibilityLibrary.scala new file mode 100644 index 000000000..dd6cc68de --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESRealtimeVisibilityLibrary.scala @@ -0,0 +1,99 @@ +package com.twitter.visibility.interfaces.des + +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.CommunityTweetFeaturesV2 +import com.twitter.visibility.builder.tweets.EditTweetFeatures +import com.twitter.visibility.builder.tweets.ExclusiveTweetFeatures +import com.twitter.visibility.builder.tweets.NilTweetLabelMaps +import com.twitter.visibility.builder.tweets.TrustedFriendsFeatures +import com.twitter.visibility.builder.tweets.TweetFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.CommunitiesSource +import com.twitter.visibility.common.TrustedFriendsSource +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.{thriftscala => vfthrift} + +case class DESRealtimeVisibilityRequest(tweet: Tweet, author: User, viewer: Option[User]) + +object DESRealtimeVisibilityLibrary { + type Type = DESRealtimeVisibilityRequest => Stitch[vfthrift.Action] + + private[this] val safetyLevel = SafetyLevel.DesRealtime + + def apply(visibilityLibrary: VisibilityLibrary): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val tweetFeatures = new TweetFeatures(NilTweetLabelMaps, libraryStatsReceiver) + + val authorFeatures = new AuthorFeatures(UserSource.empty, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(UserSource.empty, libraryStatsReceiver) + val communityTweetFeatures = new CommunityTweetFeaturesV2(CommunitiesSource.empty) + val exclusiveTweetFeatures = + new ExclusiveTweetFeatures(UserRelationshipSource.empty, libraryStatsReceiver) + val trustedFriendsTweetFeatures = new TrustedFriendsFeatures(TrustedFriendsSource.empty) + val editTweetFeatures = new 
EditTweetFeatures(libraryStatsReceiver) + + { request: DESRealtimeVisibilityRequest => + vfEngineCounter.incr() + + val tweet = request.tweet + val author = request.author + val viewer = request.viewer + val viewerContext = ViewerContext.fromContext + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + tweetFeatures.forTweetWithoutSafetyLabels(tweet), + authorFeatures.forAuthorNoDefaults(author), + viewerFeatures.forViewerNoDefaults(viewer), + communityTweetFeatures.forTweetOnly(tweet), + exclusiveTweetFeatures.forTweetOnly(tweet), + trustedFriendsTweetFeatures.forTweetOnly(tweet), + editTweetFeatures.forTweet(tweet), + ) + ) + + val tweetResult = visibilityLibrary.runRuleEngine( + ContentId.TweetId(tweet.id), + featureMap, + viewerContext, + safetyLevel + ) + val authorResult = visibilityLibrary.runRuleEngine( + ContentId.UserId(author.id), + featureMap, + viewerContext, + safetyLevel + ) + + Stitch.join(tweetResult, authorResult).map { + case (tweetResult, authorResult) => mergeResults(tweetResult, authorResult) + } + } + } + + def mergeResults( + tweetResult: VisibilityResult, + authorResult: VisibilityResult, + ): vfthrift.Action = { + Set(tweetResult.verdict, authorResult.verdict) + .find { + case Allow => false + case _ => true + } + .map(_.toActionThrift()) + .getOrElse(Allow.toActionThrift()) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESVisibilityLibrary.scala new file mode 100644 index 000000000..b3297c67c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/des/DESVisibilityLibrary.scala @@ -0,0 +1,72 @@ +package com.twitter.visibility.interfaces.des + +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.StratoTweetLabelMaps +import com.twitter.visibility.builder.tweets.TweetFeatures +import com.twitter.visibility.common.SafetyLabelMapSource +import com.twitter.visibility.features.AuthorId +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.interfaces.common.tweets.SafetyLabelMapFetcherType +import com.twitter.visibility.models.ContentId.TweetId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class DESVisibilityRequest( + tweet: Tweet, + visibilitySurface: SafetyLevel, + viewerContext: ViewerContext) + +object DESVisibilityLibrary { + type Type = DESVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + getLabelMap: SafetyLabelMapFetcherType, + enableShimFeatureHydration: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val tweetLabelMap = new StratoTweetLabelMaps( + SafetyLabelMapSource.fromSafetyLabelMapFetcher(getLabelMap)) + val tweetFeatures = new TweetFeatures(tweetLabelMap, libraryStatsReceiver) + + { request: DESVisibilityRequest => + vfEngineCounter.incr() + + val contentId = TweetId(request.tweet.id) + val coreData = request.tweet.coreData.get + val authorId = coreData.userId + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + tweetFeatures.forTweet(request.tweet), + _.withConstantFeature(AuthorId, Set(authorId)) + ) + ) + + val isShimFeatureHydrationEnabled =
enableShimFeatureHydration() + + if (isShimFeatureHydrationEnabled) { + FeatureMap.resolve(featureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary.runRuleEngine( + contentId, + resolvedFeatureMap, + request.viewerContext, + request.visibilitySurface + ) + } + } else { + visibilityLibrary.runRuleEngine( + contentId, + featureMap, + request.viewerContext, + request.visibilitySurface + ) + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/BUILD new file mode 100644 index 000000000..5a3dcb977 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/BUILD @@ -0,0 +1,32 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "decider/src/main/scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/catalog", + "strato/src/main/scala/com/twitter/strato/client", + "strato/src/main/scala/com/twitter/strato/data", + "strato/src/main/scala/com/twitter/strato/thrift", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/dm_sources", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/dms", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/utils", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/safety_label_store:safety-label-store-scala", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityLibrary.scala new file mode 100644 index 000000000..23d547c3e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityLibrary.scala @@ -0,0 +1,94 @@ +package com.twitter.visibility.interfaces.dms + +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.dms.DmConversationFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.common.dm_sources.DmConversationSource +import com.twitter.visibility.common.stitch.StitchHelpers +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.DmConversationId +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.Reason +import 
com.twitter.visibility.rules.RuleBase +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.utils.ShimUtils + +object DmConversationVisibilityLibrary { + type Type = DmConversationVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + stratoClient: StratoClient, + userSource: UserSource, + enableVfFeatureHydrationInShim: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfLatencyStatsReceiver = visibilityLibrary.statsReceiver.scope("vf_latency") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val dmConversationSource = + DmConversationSource.fromStrato(stratoClient, stratoClientStatsReceiver) + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val dmConversationFeatures = new DmConversationFeatures(dmConversationSource, authorFeatures) + + { req: DmConversationVisibilityRequest => + val dmConversationId = req.dmConversationId + val contentId = DmConversationId(dmConversationId) + val safetyLevel = req.safetyLevel + + if (!RuleBase.hasDmConversationRules(safetyLevel)) { + Stitch.value(VisibilityResult(contentId = contentId, verdict = Drop(Reason.Unspecified))) + } else { + vfEngineCounter.incr() + + val viewerContext = req.viewerContext + val viewerId = viewerContext.userId + val isVfFeatureHydrationEnabled: Boolean = + enableVfFeatureHydrationInShim() + + val featureMap = visibilityLibrary.featureMapBuilder( + Seq(dmConversationFeatures.forDmConversationId(dmConversationId, viewerId))) + + val resp = if (isVfFeatureHydrationEnabled) { + val evaluationContext = ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = EvaluationContext( + safetyLevel, + visibilityLibrary.getParams(viewerContext, safetyLevel), + visibilityLibrary.statsReceiver) + ) + + val preFilteredFeatureMap = + ShimUtils.preFilterFeatureMap(featureMap, safetyLevel, contentId, evaluationContext) + + FeatureMap.resolve(preFilteredFeatureMap, libraryStatsReceiver).flatMap { + resolvedFeatureMap => + visibilityLibrary + .runRuleEngine( + contentId, + resolvedFeatureMap, + viewerContext, + safetyLevel + ) + } + } else { + visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + viewerContext, + safetyLevel + ) + } + + StitchHelpers.profileStitch(resp, Seq(vfLatencyStatsReceiver)) + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityRequest.scala new file mode 100644 index 000000000..0b3eac66c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmConversationVisibilityRequest.scala @@ -0,0 +1,9 @@ +package com.twitter.visibility.interfaces.dms + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class DmConversationVisibilityRequest( + dmConversationId: String, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityLibrary.scala new file mode 100644 index 000000000..d66539459 --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityLibrary.scala @@ -0,0 +1,80 @@ +package com.twitter.visibility.interfaces.dms + +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.dms.DmConversationFeatures +import com.twitter.visibility.builder.dms.DmEventFeatures +import com.twitter.visibility.builder.dms.InvalidDmEventFeatureException +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.common.dm_sources.DmConversationSource +import com.twitter.visibility.common.dm_sources.DmEventSource +import com.twitter.visibility.common.stitch.StitchHelpers +import com.twitter.visibility.models.ContentId.DmEventId +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Reason +import com.twitter.visibility.rules.RuleBase + +object DmEventVisibilityLibrary { + type Type = DmEventVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + stratoClient: StratoClient, + userSource: UserSource + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfLatencyStatsReceiver = visibilityLibrary.statsReceiver.scope("vf_latency") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val dmConversationSource = { + DmConversationSource.fromStrato(stratoClient, stratoClientStatsReceiver) + } + val dmEventSource = { + DmEventSource.fromStrato(stratoClient, stratoClientStatsReceiver) + } + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val dmConversationFeatures = new DmConversationFeatures(dmConversationSource, authorFeatures) + val dmEventFeatures = + new DmEventFeatures( + dmEventSource, + dmConversationSource, + authorFeatures, + dmConversationFeatures, + libraryStatsReceiver) + + { req: DmEventVisibilityRequest => + val dmEventId = req.dmEventId + val contentId = DmEventId(dmEventId) + val safetyLevel = req.safetyLevel + + if (!RuleBase.hasDmEventRules(safetyLevel)) { + Stitch.value(VisibilityResult(contentId = contentId, verdict = Drop(Reason.Unspecified))) + } else { + vfEngineCounter.incr() + + val viewerContext = req.viewerContext + val viewerIdOpt = viewerContext.userId + + viewerIdOpt match { + case Some(viewerId) => + val featureMap = visibilityLibrary.featureMapBuilder( + Seq(dmEventFeatures.forDmEventId(dmEventId, viewerId))) + + val resp = visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + viewerContext, + safetyLevel + ) + StitchHelpers.profileStitch(resp, Seq(vfLatencyStatsReceiver)) + + case None => Stitch.exception(InvalidDmEventFeatureException("Viewer id is missing")) + } + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityRequest.scala new file mode 100644 index 000000000..c7e9c65af --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmEventVisibilityRequest.scala @@ -0,0 +1,9 @@ +package com.twitter.visibility.interfaces.dms + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class 
DmEventVisibilityRequest( + dmEventId: Long, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmVisibilityLibrary.scala new file mode 100644 index 000000000..d6a7d3ce5 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/DmVisibilityLibrary.scala @@ -0,0 +1,88 @@ +package com.twitter.visibility.interfaces.dms + +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.common.DmId +import com.twitter.visibility.common.UserId +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.{DmId => DmContentId} +import com.twitter.visibility.models.SafetyLevel.DirectMessages +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Reason.DeactivatedAuthor +import com.twitter.visibility.rules.Reason.ErasedAuthor +import com.twitter.visibility.rules.Reason.Nsfw + +object DmVisibilityLibrary { + type Type = DmVisibilityRequest => Stitch[DmVisibilityResponse] + + case class DmVisibilityRequest( + dmId: DmId, + dmAuthorUserId: UserId, + viewerContext: ViewerContext) + + case class DmVisibilityResponse(isMessageNsfw: Boolean) + + val DefaultSafetyLevel: SafetyLevel = DirectMessages + + def apply( + visibilityLibrary: VisibilityLibrary, + stratoClient: StratoClient, + userSource: UserSource, + enableVfFeatureHydrationInShim: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + + { r: DmVisibilityRequest => + vfEngineCounter.incr() + + val contentId = DmContentId(r.dmId) + val dmAuthorUserId = r.dmAuthorUserId + val isVfFeatureHydrationEnabled = enableVfFeatureHydrationInShim() + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq(authorFeatures.forAuthorId(dmAuthorUserId)) + ) + + val resp = if (isVfFeatureHydrationEnabled) { + FeatureMap.resolve(featureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary.runRuleEngine( + contentId, + resolvedFeatureMap, + r.viewerContext, + DefaultSafetyLevel + ) + } + } else { + visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + r.viewerContext, + DefaultSafetyLevel + ) + } + + resp.map(buildResponse) + } + } + + private[this] def buildResponse(visibilityResult: VisibilityResult) = + visibilityResult.verdict match { + case Drop(Nsfw | ErasedAuthor | DeactivatedAuthor, _) => + DmVisibilityResponse(isMessageNsfw = true) + case _ => + DmVisibilityResponse(isMessageNsfw = false) + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/package.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/package.scala new file mode 100644 index 000000000..0b8f4b410 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/dms/package.scala @@ -0,0 +1,12 @@ +package 
com.twitter.visibility.interfaces + +import com.twitter.stitch.Stitch +import com.twitter.visibility.common.DmId +import com.twitter.visibility.safety_label_store.thriftscala.DmSafetyLabelMap + +package object dms { + type DmSafetyLabelMapFetcherType = DmId => Stitch[Option[DmSafetyLabelMap]] + + val DmSafetyLabelMapFetcherStratoColumn = + "visibility/safety-label-store/vflib/dm/safetyLabelMap.Dm" +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/BUILD.bazel b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/BUILD.bazel new file mode 100644 index 000000000..1fe0d3dd8 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/BUILD.bazel @@ -0,0 +1,16 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + strict_deps = False, + tags = ["bazel-compatible"], + dependencies = [ + "mediaservices/media-util", + "stitch/stitch-core", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/media", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/utils", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityLibrary.scala new file mode 100644 index 000000000..79f25d3ea --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityLibrary.scala @@ -0,0 +1,89 @@ +package com.twitter.visibility.interfaces.media + +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.util.Stopwatch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.builder.media.MediaFeatures +import com.twitter.visibility.builder.media.MediaMetadataFeatures +import com.twitter.visibility.builder.media.StratoMediaLabelMaps +import com.twitter.visibility.common.MediaMetadataSource +import com.twitter.visibility.common.MediaSafetyLabelMapSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.generators.TombstoneGenerator +import com.twitter.visibility.models.ContentId.MediaId +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.utils.ShimUtils + +object MediaVisibilityLibrary { + type Type = MediaVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource, + tombstoneGenerator: TombstoneGenerator, + stratoClient: StratoClient, + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + + val stratoClientStatsReceiver = libraryStatsReceiver.scope("strato") + + val mediaMetadataFeatures = new MediaMetadataFeatures( + 
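/* note: the metadata source is Strato-backed; its calls are reported under the "strato" stats scope */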
MediaMetadataSource.fromStrato(stratoClient, stratoClientStatsReceiver), + libraryStatsReceiver) + + val mediaLabelMaps = new StratoMediaLabelMaps( + MediaSafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + val mediaFeatures = new MediaFeatures(mediaLabelMaps, libraryStatsReceiver) + + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + + { r: MediaVisibilityRequest => + vfEngineCounter.incr() + + val contentId = MediaId(r.mediaKey.toStringKey) + val languageCode = r.viewerContext.requestLanguageCode.getOrElse("en") + + val featureMap = visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerContext(r.viewerContext), + mediaFeatures.forGenericMediaKey(r.mediaKey), + mediaMetadataFeatures.forGenericMediaKey(r.mediaKey), + ) + ) + + val evaluationContext = ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = EvaluationContext( + r.safetyLevel, + visibilityLibrary.getParams(r.viewerContext, r.safetyLevel), + visibilityLibrary.statsReceiver) + ) + + val preFilteredFeatureMap = + ShimUtils.preFilterFeatureMap(featureMap, r.safetyLevel, contentId, evaluationContext) + + val elapsed = Stopwatch.start() + FeatureMap.resolve(preFilteredFeatureMap, libraryStatsReceiver).flatMap { + resolvedFeatureMap => + vfLatencyStitchRunStat.add(elapsed().inMilliseconds) + + visibilityLibrary + .runRuleEngine( + contentId, + resolvedFeatureMap, + r.viewerContext, + r.safetyLevel + ) + .map(tombstoneGenerator(_, languageCode)) + .onSuccess(_ => vfLatencyOverallStat.add(elapsed().inMilliseconds)) + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityRequest.scala new file mode 100644 index 000000000..0f6b78f77 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/media/MediaVisibilityRequest.scala @@ -0,0 +1,10 @@ +package com.twitter.visibility.interfaces.media + +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class MediaVisibilityRequest( + mediaKey: GenericMediaKey, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/BUILD new file mode 100644 index 000000000..37b3f5dba --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/BUILD @@ -0,0 +1,31 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias", + "notificationservice/common/src/main/scala/com/twitter/notificationservice/model/notification", + "src/thrift/com/twitter/gizmoduck:thrift-scala", + "src/thrift/com/twitter/socialgraph:thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-label-scala", + "stitch/stitch-core", + "stitch/stitch-gizmoduck/src/main/scala", + "stitch/stitch-socialgraph", + "stitch/stitch-socialgraph/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + 
"visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + ], + exports = [ + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationVFRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationVFRequest.scala new file mode 100644 index 000000000..edf92aeab --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationVFRequest.scala @@ -0,0 +1,9 @@ +package com.twitter.visibility.interfaces.notifications + +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel + +case class NotificationVFRequest( + recipientId: Long, + subject: ContentId.UserId, + safetyLevel: SafetyLevel) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsFilteringResponse.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsFilteringResponse.scala new file mode 100644 index 000000000..bc500cce6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsFilteringResponse.scala @@ -0,0 +1,13 @@ +package com.twitter.visibility.interfaces.notifications + +import com.twitter.visibility.features.Feature +import com.twitter.visibility.rules.Action +import scala.collection.immutable.Set + +sealed trait NotificationsFilteringResponse + +case object Allow extends NotificationsFilteringResponse + +case class Filtered(action: Action) extends NotificationsFilteringResponse + +case class Failed(features: Set[Feature[_]]) extends NotificationsFilteringResponse diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformFilteringResponse.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformFilteringResponse.scala new file mode 100644 index 000000000..0a5624e49 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformFilteringResponse.scala @@ -0,0 +1,13 @@ +package com.twitter.visibility.interfaces.notifications + +import com.twitter.visibility.features.Feature +import com.twitter.visibility.rules.Action + +trait NotificationsPlatformFilteringResponse + +case object AllowVerdict extends NotificationsPlatformFilteringResponse + +case class FilteredVerdict(action: Action) extends NotificationsPlatformFilteringResponse + +case class FailedVerdict(featuresMap: Map[Feature[_], String]) + extends NotificationsPlatformFilteringResponse diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformVisibilityLibrary.scala new file mode 100644 index 000000000..bdd2f59f1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsPlatformVisibilityLibrary.scala @@ -0,0 
+1,157 @@ +package com.twitter.visibility.interfaces.notifications + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.util.Throwables +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.CommunityNotificationFeatures +import com.twitter.visibility.builder.tweets.UnmentionNotificationFeatures +import com.twitter.visibility.builder.users.AuthorDeviceFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerAdvancedFilteringFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.UserDeviceSource +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.AuthorUserLabels +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.State.FeatureFailed +import com.twitter.visibility.rules.State.MissingFeature +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.RuleResult +import com.twitter.visibility.rules.{Allow => AllowAction} + +object NotificationsPlatformVisibilityLibrary { + type NotificationsPlatformVFType = + NotificationVFRequest => Stitch[NotificationsPlatformFilteringResponse] + + private val AllowResponse: Stitch[NotificationsPlatformFilteringResponse] = + Stitch.value(AllowVerdict) + + def apply( + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + userDeviceSource: UserDeviceSource, + visibilityLibrary: VisibilityLibrary, + enableShimFeatureHydration: Gate[Unit] = Gate.False + ): NotificationsPlatformVFType = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val authorDeviceFeatures = new AuthorDeviceFeatures(userDeviceSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + + val viewerAdvancedFilteringFeatures = + new ViewerAdvancedFilteringFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + + val isShimFeatureHydrationEnabled = enableShimFeatureHydration() + + def runRuleEngine(candidate: NotificationVFRequest): Stitch[VisibilityResult] = { + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerId(Some(candidate.recipientId)), + viewerAdvancedFilteringFeatures.forViewerId(Some(candidate.recipientId)), + authorFeatures.forAuthorId(candidate.subject.id), + authorDeviceFeatures.forAuthorId(candidate.subject.id), + relationshipFeatures.forAuthorId(candidate.subject.id, Some(candidate.recipientId)), + CommunityNotificationFeatures.ForNonCommunityTweetNotification, + UnmentionNotificationFeatures.ForNonUnmentionNotificationFeatures + ) + ) + + vfEngineCounter.incr() + + if (isShimFeatureHydrationEnabled) { + FeatureMap.resolve(featureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary.runRuleEngine( + contentId = candidate.subject, + featureMap = resolvedFeatureMap, + viewerContext = + 
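/* the notification recipient is treated as the viewer for rule evaluation */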
ViewerContext.fromContextWithViewerIdFallback(Some(candidate.recipientId)), + safetyLevel = candidate.safetyLevel + ) + } + } else { + visibilityLibrary.runRuleEngine( + contentId = candidate.subject, + featureMap = featureMap, + viewerContext = + ViewerContext.fromContextWithViewerIdFallback(Some(candidate.recipientId)), + safetyLevel = candidate.safetyLevel + ) + } + } + + { + case candidate: NotificationVFRequest => + runRuleEngine(candidate).flatMap(failCloseForFailures(_, libraryStatsReceiver)) + case _ => + AllowResponse + } + } + + private def failCloseForFailures( + visibilityResult: VisibilityResult, + stats: StatsReceiver + ): Stitch[NotificationsPlatformFilteringResponse] = { + lazy val vfEngineSuccess = stats.counter("vf_engine_success") + lazy val vfEngineFailures = stats.counter("vf_engine_failures") + lazy val vfEngineFailuresMissing = stats.scope("vf_engine_failures").counter("missing") + lazy val vfEngineFailuresFailed = stats.scope("vf_engine_failures").counter("failed") + lazy val vfEngineFiltered = stats.counter("vf_engine_filtered") + + val isFailedOrMissingFeature: RuleResult => Boolean = { + case RuleResult(_, FeatureFailed(features)) => + !(features.contains(AuthorUserLabels) && features.size == 1) + case RuleResult(_, MissingFeature(_)) => true + case _ => false + } + + val failedRuleResults = + visibilityResult.ruleResultMap.values.filter(isFailedOrMissingFeature(_)) + + val (failedFeatures, missingFeatures) = failedRuleResults.partition { + case RuleResult(_, FeatureFailed(_)) => true + case RuleResult(_, MissingFeature(_)) => false + case _ => false + } + + val failedOrMissingFeatures: Map[Feature[_], String] = failedRuleResults + .collect { + case RuleResult(_, FeatureFailed(features)) => + features.map { + case (feature: Feature[_], throwable: Throwable) => + feature -> Throwables.mkString(throwable).mkString(" -> ") + }.toSet + case RuleResult(_, MissingFeature(features)) => features.map(_ -> "Feature missing.") + }.flatten.toMap + + visibilityResult.verdict match { + case AllowAction if failedOrMissingFeatures.isEmpty => + vfEngineSuccess.incr() + AllowResponse + case AllowAction if failedOrMissingFeatures.nonEmpty => + vfEngineFailures.incr() + if (missingFeatures.nonEmpty) { + vfEngineFailuresMissing.incr() + } + if (failedFeatures.nonEmpty) { + vfEngineFailuresFailed.incr() + } + + Stitch.value(FailedVerdict(failedOrMissingFeatures)) + case action: Action => + vfEngineFiltered.incr() + Stitch.value(FilteredVerdict(action)) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsVisibilityLibrary.scala new file mode 100644 index 000000000..c6b99044c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/notifications/NotificationsVisibilityLibrary.scala @@ -0,0 +1,181 @@ +package com.twitter.visibility.interfaces.notifications + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.notificationservice.model.notification.Notification +import com.twitter.notificationservice.model.notification.NotificationType +import com.twitter.notificationservice.model.notification.SimpleActivityNotification +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.CommunityNotificationFeatures 
+import com.twitter.visibility.builder.tweets.UnmentionNotificationFeatures +import com.twitter.visibility.builder.users.AuthorDeviceFeatures +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerAdvancedFilteringFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.TweetSource +import com.twitter.visibility.common.UserDeviceSource +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.AuthorUserLabels +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.NotificationId +import com.twitter.visibility.models.SafetyLevel.NotificationsWriterV2 +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.State.FeatureFailed +import com.twitter.visibility.rules.State.MissingFeature +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.RuleResult +import com.twitter.visibility.rules.{Allow => AllowAction} + +object NotificationsVisibilityLibrary { + type Type = Notification => Stitch[NotificationsFilteringResponse] + + private val AllowResponse: Stitch[NotificationsFilteringResponse] = Stitch.value(Allow) + + def isApplicableOrganicNotificationType(notificationType: NotificationType): Boolean = { + NotificationType.isTlsActivityType(notificationType) || + NotificationType.isReactionType(notificationType) + } + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + userDeviceSource: UserDeviceSource, + tweetSource: TweetSource, + enableShimFeatureHydration: Gate[Unit] = Gate.False, + enableCommunityTweetHydration: Gate[Long] = Gate.False, + enableUnmentionHydration: Gate[Long] = Gate.False, + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + lazy val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val authorDeviceFeatures = new AuthorDeviceFeatures(userDeviceSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val communityNotificationFeatures = + new CommunityNotificationFeatures( + tweetSource, + enableCommunityTweetHydration, + libraryStatsReceiver) + + val unmentionNotificationFeatures = new UnmentionNotificationFeatures( + tweetSource = tweetSource, + enableUnmentionHydration = enableUnmentionHydration, + statsReceiver = libraryStatsReceiver + ) + + val viewerAdvancedFilteringFeatures = + new ViewerAdvancedFilteringFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + + val isShimFeatureHydrationEnabled = enableShimFeatureHydration() + + def runRuleEngine( + visibilityLibrary: VisibilityLibrary, + candidate: Notification + ): Stitch[VisibilityResult] = { + candidate match { + case notification: SimpleActivityNotification[_] => + vfEngineCounter.incr() + + val featureMap = visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerId(Some(notification.target)), + viewerAdvancedFilteringFeatures.forViewerId(Some(notification.target)), + authorFeatures.forAuthorId(notification.subjectId), + authorDeviceFeatures.forAuthorId(notification.subjectId), + relationshipFeatures + 
.forAuthorId(notification.subjectId, Some(notification.target)), + communityNotificationFeatures.forNotification(notification), + unmentionNotificationFeatures.forNotification(notification) + ) + ) + + if (isShimFeatureHydrationEnabled) { + FeatureMap.resolve(featureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary.runRuleEngine( + contentId = NotificationId(tweetId = None), + featureMap = resolvedFeatureMap, + viewerContext = + ViewerContext.fromContextWithViewerIdFallback(Some(notification.target)), + safetyLevel = NotificationsWriterV2 + ) + } + } else { + visibilityLibrary.runRuleEngine( + contentId = NotificationId(tweetId = None), + featureMap = featureMap, + viewerContext = + ViewerContext.fromContextWithViewerIdFallback(Some(notification.target)), + safetyLevel = NotificationsWriterV2 + ) + } + } + } + + { + case candidate if isApplicableOrganicNotificationType(candidate.notificationType) => + runRuleEngine(visibilityLibrary, candidate) + .flatMap(failCloseForFailures(_, libraryStatsReceiver)) + case _ => + AllowResponse + } + } + + def failCloseForFailures( + visibilityResult: VisibilityResult, + stats: StatsReceiver + ): Stitch[NotificationsFilteringResponse] = { + lazy val vfEngineSuccess = stats.counter("vf_engine_success") + lazy val vfEngineFailures = stats.counter("vf_engine_failures") + lazy val vfEngineFailuresMissing = stats.scope("vf_engine_failures").counter("missing") + lazy val vfEngineFailuresFailed = stats.scope("vf_engine_failures").counter("failed") + lazy val vfEngineFiltered = stats.counter("vf_engine_filtered") + + val isFailedOrMissingFeature: RuleResult => Boolean = { + case RuleResult(_, FeatureFailed(features)) => + !(features.contains(AuthorUserLabels) && features.size == 1) + case RuleResult(_, MissingFeature(_)) => true + case _ => false + } + + val failedRuleResults = + visibilityResult.ruleResultMap.values.filter(isFailedOrMissingFeature(_)) + + val (failedFeatures, missingFeatures) = failedRuleResults.partition { + case RuleResult(_, FeatureFailed(_)) => true + case RuleResult(_, MissingFeature(_)) => false + case _ => false + } + + val failedOrMissingFeatures = failedRuleResults + .collect { + case RuleResult(_, FeatureFailed(features)) => features.keySet + case RuleResult(_, MissingFeature(features)) => features + }.toSet.flatten + + visibilityResult.verdict match { + case AllowAction if failedOrMissingFeatures.isEmpty => + vfEngineSuccess.incr() + AllowResponse + case AllowAction if failedOrMissingFeatures.nonEmpty => + vfEngineFailures.incr() + if (missingFeatures.nonEmpty) { + vfEngineFailuresMissing.incr() + } + if (failedFeatures.nonEmpty) { + vfEngineFailuresFailed.incr() + } + + Stitch.value(Failed(failedOrMissingFeatures)) + case action: Action => + vfEngineFiltered.incr() + Stitch.value(Filtered(action)) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/BUILD.bazel b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/BUILD.bazel new file mode 100644 index 000000000..0436ca3c7 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/BUILD.bazel @@ -0,0 +1,31 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/twitter/storehaus:core", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:events-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", +
"stitch/stitch-core", + "stitch/stitch-tweetypie/src/main/scala", + "strato/src/main/scala/com/twitter/strato/catalog", + "strato/src/main/scala/com/twitter/strato/client", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/media", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceSafetyLabelMapFetcher.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceSafetyLabelMapFetcher.scala new file mode 100644 index 000000000..f7ce43392 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceSafetyLabelMapFetcher.scala @@ -0,0 +1,21 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.spam.rtf.thriftscala.SafetyLabelMap +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.strato.thrift.ScroogeConvImplicits._ +import com.twitter.visibility.common.stitch.StitchHelpers + +object PushServiceSafetyLabelMapFetcher { + val Column = "frigate/magicrecs/tweetSafetyLabels" + + def apply( + client: StratoClient, + statsReceiver: StatsReceiver + ): Long => Stitch[Option[SafetyLabelMap]] = { + val stats = statsReceiver.scope("strato_tweet_safety_labels") + lazy val fetcher = client.fetcher[Long, SafetyLabelMap](Column) + tweetId => StitchHelpers.observe(stats)(fetcher.fetch(tweetId).map(_.v)) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibrary.scala new file mode 100644 index 000000000..42b2d16ca --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibrary.scala @@ -0,0 +1,179 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.servo.util.Gate +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult +import com.twitter.storehaus.ReadableStore +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.tweets.TweetFeatures +import com.twitter.visibility.builder.tweets.StratoTweetLabelMaps +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import 
com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.common._ +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.features.TweetIsSourceTweet +import com.twitter.visibility.interfaces.push_service.PushServiceVisibilityLibraryUtil._ +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.ViewerContext + +object TweetType extends Enumeration { + type TweetType = Value + val ORIGINAL, SOURCE, QUOTED = Value +} +import com.twitter.visibility.interfaces.push_service.TweetType._ + +object PushServiceVisibilityLibrary { + type Type = PushServiceVisibilityRequest => Stitch[PushServiceVisibilityResponse] + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + stratoClient: StratoClient, + enableParityTest: Gate[Unit] = Gate.False, + cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult] = ReadableStore.empty, + safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult] = ReadableStore.empty, + )( + implicit statsReceiver: StatsReceiver + ): Type = { + val stats = statsReceiver.scope("push_service_vf") + val candidateTweetCounter = stats.counter("request_cnt") + val allowedTweetCounter = stats.counter("allow_cnt") + val droppedTweetCounter = stats.counter("drop_cnt") + val failedTweetCounter = stats.counter("fail_cnt") + val authorLabelsEmptyCount = stats.counter("author_labels_empty_cnt") + val authorLabelsCount = stats.counter("author_labels_cnt") + + val tweetLabelMaps = new StratoTweetLabelMaps( + SafetyLabelMapSource.fromSafetyLabelMapFetcher( + PushServiceSafetyLabelMapFetcher(stratoClient, stats))) + + val viewerFeatures = new ViewerFeatures(UserSource.empty, stats) + val tweetFeatures = new TweetFeatures(tweetLabelMaps, stats) + val authorFeatures = new AuthorFeatures(userSource, stats) + val relationshipFeatures = new RelationshipFeatures(UserRelationshipSource.empty, stats) + + val parityTester = new PushServiceVisibilityLibraryParity( + cachedTweetyPieStoreV2, + safeCachedTweetyPieStoreV2 + )(statsReceiver) + + def buildFeatureMap( + request: PushServiceVisibilityRequest, + tweet: Tweet, + tweetType: TweetType, + author: Option[User] = None, + ): FeatureMap = { + val authorId = author.map(_.id) orElse getAuthorId(tweet) + (author.map(authorFeatures.forAuthor(_)) orElse + getAuthorId(tweet).map(authorFeatures.forAuthorId(_))) match { + case Some(authorVisibilityFeatures) => + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerContext(ViewerContext.fromContextWithViewerIdFallback(None)), + tweetFeatures.forTweet(tweet), + authorVisibilityFeatures, + relationshipFeatures.forAuthorId(authorId.get, None), + _.withConstantFeature(TweetIsInnerQuotedTweet, tweetType == QUOTED), + _.withConstantFeature(TweetIsRetweet, request.isRetweet), + _.withConstantFeature(TweetIsSourceTweet, tweetType == SOURCE) + ) + ) + case _ => + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerContext(ViewerContext.fromContextWithViewerIdFallback(None)), + tweetFeatures.forTweet(tweet), + _.withConstantFeature(TweetIsInnerQuotedTweet, tweetType == QUOTED), + _.withConstantFeature(TweetIsRetweet, request.isRetweet), + 
_.withConstantFeature(TweetIsSourceTweet, tweetType == SOURCE) + ) + ) + } + } + + def runRuleEngineForTweet( + request: PushServiceVisibilityRequest, + tweet: Tweet, + tweetType: TweetType, + author: Option[User] = None, + ): Stitch[VisibilityResult] = { + val featureMap = buildFeatureMap(request, tweet, tweetType, author) + val contentId = ContentId.TweetId(tweet.id) + visibilityLibrary.runRuleEngine( + contentId, + featureMap, + request.viewerContext, + request.safetyLevel) + } + + def runRuleEngineForAuthor( + request: PushServiceVisibilityRequest, + tweet: Tweet, + tweetType: TweetType, + author: Option[User] = None, + ): Stitch[VisibilityResult] = { + val featureMap = buildFeatureMap(request, tweet, tweetType, author) + val authorId = author.map(_.id).getOrElse(getAuthorId(tweet).get) + val contentId = ContentId.UserId(authorId) + visibilityLibrary.runRuleEngine( + contentId, + featureMap, + request.viewerContext, + request.safetyLevel) + } + + def getAllVisibilityFilters( + request: PushServiceVisibilityRequest + ): Stitch[PushServiceVisibilityResponse] = { + val tweetResult = + runRuleEngineForTweet(request, request.tweet, ORIGINAL, Some(request.author)) + val authorResult = + runRuleEngineForAuthor(request, request.tweet, ORIGINAL, Some(request.author)) + val sourceTweetResult = request.sourceTweet + .map(runRuleEngineForTweet(request, _, SOURCE).map(Some(_))).getOrElse(Stitch.None) + val quotedTweetResult = request.quotedTweet + .map(runRuleEngineForTweet(request, _, QUOTED).map(Some(_))).getOrElse(Stitch.None) + + Stitch.join(tweetResult, authorResult, sourceTweetResult, quotedTweetResult).map { + case (tweetResult, authorResult, sourceTweetResult, quotedTweetResult) => + PushServiceVisibilityResponse( + tweetResult, + authorResult, + sourceTweetResult, + quotedTweetResult) + } + } + + { request: PushServiceVisibilityRequest => + candidateTweetCounter.incr() + + request.author.labels match { + case Some(labels) if (!labels._1.isEmpty) => authorLabelsCount.incr() + case _ => authorLabelsEmptyCount.incr() + } + + val response = getAllVisibilityFilters(request) + .onSuccess { response => + if (response.shouldAllow) allowedTweetCounter.incr() else droppedTweetCounter.incr() + }.onFailure { _ => failedTweetCounter.incr() } + + if (enableParityTest()) { + response.applyEffect { resp => Stitch.async(parityTester.runParityTest(request, resp)) } + } else { + response + } + + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryParity.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryParity.scala new file mode 100644 index 000000000..cc1b27e24 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryParity.scala @@ -0,0 +1,74 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult +import com.twitter.storehaus.ReadableStore +import com.twitter.logging.Logger +import com.twitter.visibility.models.SafetyLevel + +class PushServiceVisibilityLibraryParity( + magicRecsV2tweetyPieStore: ReadableStore[Long, TweetyPieResult], + magicRecsAggressiveV2tweetyPieStore: ReadableStore[Long, TweetyPieResult] +)( + implicit statsReceiver: StatsReceiver) { + + private val stats = statsReceiver.scope("push_service_vf_parity") + private val requests = 
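/* incremented once per parity comparison */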
stats.counter("requests") + private val equal = stats.counter("equal") + private val notEqual = stats.counter("notEqual") + private val failures = stats.counter("failures") + private val bothAllow = stats.counter("bothAllow") + private val bothReject = stats.counter("bothReject") + private val onlyTweetypieRejects = stats.counter("onlyTweetypieRejects") + private val onlyPushServiceRejects = stats.counter("onlyPushServiceRejects") + + val log = Logger.get("pushservice_vf_parity") + + def runParityTest( + req: PushServiceVisibilityRequest, + resp: PushServiceVisibilityResponse + ): Stitch[Unit] = { + requests.incr() + getTweetypieResult(req).map { tweetypieResult => + val isSameVerdict = (tweetypieResult == resp.shouldAllow) + isSameVerdict match { + case true => equal.incr() + case false => notEqual.incr() + } + (tweetypieResult, resp.shouldAllow) match { + case (true, true) => bothAllow.incr() + case (true, false) => onlyPushServiceRejects.incr() + case (false, true) => onlyTweetypieRejects.incr() + case (false, false) => bothReject.incr() + } + + resp.getDropRules.foreach { dropRule => + stats.counter(s"rules/${dropRule.name}/requests").incr() + stats + .counter( + s"rules/${dropRule.name}/" ++ (if (isSameVerdict) "equal" else "notEqual")).incr() + } + + if (!isSameVerdict) { + val dropRuleNames = resp.getDropRules.map("<<" ++ _.name ++ ">>").mkString(",") + val safetyLevelStr = req.safetyLevel match { + case SafetyLevel.MagicRecsAggressiveV2 => "aggr" + case _ => " " + } + log.info( + s"ttweetId:${req.tweet.id} () push:${resp.shouldAllow}, tweety:${tweetypieResult}, rules=[${dropRuleNames}] lvl=${safetyLevelStr}") + } + } + + } + + def getTweetypieResult(request: PushServiceVisibilityRequest): Stitch[Boolean] = { + val tweetypieStore = request.safetyLevel match { + case SafetyLevel.MagicRecsAggressiveV2 => magicRecsAggressiveV2tweetyPieStore + case _ => magicRecsV2tweetyPieStore + } + Stitch.callFuture( + tweetypieStore.get(request.tweet.id).onFailure(_ => failures.incr()).map(x => x.isDefined)) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryUtil.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryUtil.scala new file mode 100644 index 000000000..0f0321afe --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityLibraryUtil.scala @@ -0,0 +1,57 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.RuleResult +import com.twitter.visibility.rules.State + +object PushServiceVisibilityLibraryUtil { + def ruleEnabled(ruleResult: RuleResult): Boolean = { + ruleResult.state match { + case State.Disabled => false + case State.ShortCircuited => false + case _ => true + } + } + def getMissingFeatures(ruleResult: RuleResult): Set[String] = { + ruleResult.state match { + case State.MissingFeature(features) => features.map(f => f.name) + case _ => Set.empty + } + } + def getMissingFeatureCounts(results: Seq[VisibilityResult]): Map[String, Int] = { + results + .flatMap(_.ruleResultMap.values.toList) + .flatMap(getMissingFeatures(_).toList).groupBy(identity).mapValues(_.length) + } + + def logAllStats( + response: PushServiceVisibilityResponse + )( + implicit statsReceiver: 
StatsReceiver + ) = { + val rulesStatsReceiver = statsReceiver.scope("rules") + logStats(response.tweetVisibilityResult, rulesStatsReceiver.scope("tweet")) + logStats(response.authorVisibilityResult, rulesStatsReceiver.scope("author")) + } + + def logStats(result: VisibilityResult, statsReceiver: StatsReceiver) = { + result.ruleResultMap.toList + .filter { case (_, ruleResult) => ruleEnabled(ruleResult) } + .flatMap { case (rule, ruleResult) => getCounters(rule, ruleResult) } + .foreach(statsReceiver.counter(_).incr()) + } + + def getCounters(rule: Rule, ruleResult: RuleResult): List[String] = { + val missingFeatures = getMissingFeatures(ruleResult) + List(s"${rule.name}/${ruleResult.action.name}") ++ + missingFeatures.map(feat => s"${rule.name}/${feat}") ++ + missingFeatures + } + + def getAuthorId(tweet: Tweet): Option[Long] = tweet.coreData.map(_.userId) + def isRetweet(tweet: Tweet): Boolean = tweet.coreData.flatMap(_.share).isDefined + def isQuotedTweet(tweet: Tweet): Boolean = tweet.quotedTweet.isDefined +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityRequest.scala new file mode 100644 index 000000000..b773deec9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityRequest.scala @@ -0,0 +1,19 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.gizmoduck.thriftscala.User +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class PushServiceVisibilityRequest( + tweet: Tweet, + author: User, + viewerContext: ViewerContext, + safetyLevel: SafetyLevel, + sourceTweet: Option[Tweet] = None, + quotedTweet: Option[Tweet] = None, + isRetweet: Boolean = false, + isInnerQuotedTweet: Boolean = false, + isSourceTweet: Boolean = false, + isOutOfNetworkTweet: Boolean = true, +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityResponse.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityResponse.scala new file mode 100644 index 000000000..a3598fd61 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/push_service/PushServiceVisibilityResponse.scala @@ -0,0 +1,52 @@ +package com.twitter.visibility.interfaces.push_service + +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.RuleResult + +case class PushServiceVisibilityResponse( + tweetVisibilityResult: VisibilityResult, + authorVisibilityResult: VisibilityResult, + sourceTweetVisibilityResult: Option[VisibilityResult] = None, + quotedTweetVisibilityResult: Option[VisibilityResult] = None, +) { + + def allVisibilityResults: List[VisibilityResult] = { + List( + Some(tweetVisibilityResult), + Some(authorVisibilityResult), + sourceTweetVisibilityResult, + quotedTweetVisibilityResult, + ).collect { case Some(result) => result } + } + + val shouldAllow: Boolean = !allVisibilityResults.exists(isDrop(_)) + + def isDrop(response: VisibilityResult): Boolean = response.verdict match { + case _: Drop => true + case Allow => false + case _ => 
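/* non-Drop verdicts (e.g. interstitials) still count as visible */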
false + } + def isDrop(response: Option[VisibilityResult]): Boolean = response.map(isDrop(_)).getOrElse(false) + + def getDropRules(visibilityResult: VisibilityResult): List[Rule] = { + val ruleResultMap = visibilityResult.ruleResultMap + val ruleResults = ruleResultMap.toList + val denyRules = ruleResults.collect { case (rule, RuleResult(Drop(_, _), _)) => rule } + denyRules + } + def getAuthorDropRules: List[Rule] = getDropRules(authorVisibilityResult) + def getTweetDropRules: List[Rule] = getDropRules(tweetVisibilityResult) + def getDropRules: List[Rule] = getAuthorDropRules ++ getTweetDropRules + def getVerdict: Action = { + if (isDrop(authorVisibilityResult)) authorVisibilityResult.verdict + else tweetVisibilityResult.verdict + } + + def missingFeatures: Map[String, Int] = PushServiceVisibilityLibraryUtil.getMissingFeatureCounts( + Seq(tweetVisibilityResult, authorVisibilityResult)) + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BUILD new file mode 100644 index 000000000..2e758bc84 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BUILD @@ -0,0 +1,34 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/twitter/src/java/com/twitter/logpipeline/client:logpipeline-event-publisher-thin", + "decider/src/main/scala", + "mediaservices/media-util/src/main/scala", + "servo/decider/src/main/scala", + "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/client", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/media", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/search", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityRequest.scala new file mode 100644 index 000000000..37a294825 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityRequest.scala @@ -0,0 +1,9 @@ +package com.twitter.visibility.interfaces.search + +import com.twitter.visibility.interfaces.common.search.SearchVFRequestContext +import com.twitter.visibility.models.ViewerContext + +case class BatchSearchVisibilityRequest( + tweetContexts: Seq[TweetContext], + 
viewerContext: ViewerContext, + searchVFRequestContext: SearchVFRequestContext) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityResponse.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityResponse.scala new file mode 100644 index 000000000..3eb7918dc --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/BatchSearchVisibilityResponse.scala @@ -0,0 +1,5 @@ +package com.twitter.visibility.interfaces.search + +case class BatchSearchVisibilityResponse( + visibilityResults: Map[Long, CombinedVisibilityResult], + failedTweetIds: Seq[Long]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/CombinedVisibilityResult.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/CombinedVisibilityResult.scala new file mode 100644 index 000000000..e71841783 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/CombinedVisibilityResult.scala @@ -0,0 +1,7 @@ +package com.twitter.visibility.interfaces.search + +import com.twitter.visibility.builder.VisibilityResult + +case class CombinedVisibilityResult( + tweetVisibilityResult: VisibilityResult, + quotedTweetVisibilityResult: Option[VisibilityResult]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/SearchVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/SearchVisibilityLibrary.scala new file mode 100644 index 000000000..ea46ab741 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/SearchVisibilityLibrary.scala @@ -0,0 +1,466 @@ +package com.twitter.visibility.interfaces.search + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.util.Return +import com.twitter.util.Stopwatch +import com.twitter.util.Try +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VerdictLogger +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.media.MediaFeatures +import com.twitter.visibility.builder.media.StratoMediaLabelMaps +import com.twitter.visibility.builder.tweets._ +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.MediaSafetyLabelMapSource +import com.twitter.visibility.common.MisinformationPolicySource +import com.twitter.visibility.common.SafetyLabelMapSource +import com.twitter.visibility.common.TrustedFriendsSource +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.rules.ComposableActions._ +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.features.TweetIsSourceTweet +import com.twitter.visibility.interfaces.common.search.SearchVFRequestContext +import 
com.twitter.visibility.interfaces.search.SearchVisibilityLibrary.EvaluateTweet +import com.twitter.visibility.interfaces.search.SearchVisibilityLibrary.RequestTweetId +import com.twitter.visibility.interfaces.search.TweetType.EvaluateTweetType +import com.twitter.visibility.logging.thriftscala.VFLibType +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.ContentId.BlenderTweetId +import com.twitter.visibility.models.ContentId.TweetId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel.toThrift +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Interstitial +import com.twitter.visibility.rules.TweetInterstitial + +object TweetType extends Enumeration { + type EvaluateTweetType = Value + val REQUEST: TweetType.Value = Value(1) + val QUOTED: TweetType.Value = Value(2) + val SOURCE: TweetType.Value = Value(3) +} + +import com.twitter.visibility.interfaces.search.TweetType._ + +object SearchVisibilityLibrary { + type RequestTweetId = Long + type EvaluateTweetId = Long + type EvaluateTweet = Tweet + + def buildWithStratoClient( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource + ): SearchVisibilityLibrary = new SearchVisibilityLibrary( + visibilityLibrary, + decider, + stratoClient, + userSource, + userRelationshipSource, + None + ) + + def buildWithSafetyLabelMapSource( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + safetyLabelMapSource: SafetyLabelMapSource + ): SearchVisibilityLibrary = new SearchVisibilityLibrary( + visibilityLibrary, + decider, + stratoClient, + userSource, + userRelationshipSource, + Some(safetyLabelMapSource) + ) + + def createVerdictLogger( + enableVerdictLogger: Gate[Unit], + decider: Decider, + statsReceiver: StatsReceiver + ): VerdictLogger = { + if (enableVerdictLogger()) { + VerdictLogger(statsReceiver, decider) + } else { + VerdictLogger.Empty + } + } + + def scribeVisibilityVerdict( + result: CombinedVisibilityResult, + enableVerdictScribing: Gate[Unit], + verdictLogger: VerdictLogger, + viewerId: Option[Long], + safetyLevel: SafetyLevel + ): Unit = if (enableVerdictScribing()) { + verdictLogger.scribeVerdict( + visibilityResult = result.tweetVisibilityResult, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.SearchVisibilityLibrary) + + result.quotedTweetVisibilityResult.map(quotedTweetVisibilityResult => + verdictLogger.scribeVerdict( + visibilityResult = quotedTweetVisibilityResult, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.SearchVisibilityLibrary)) + } +} + +class SearchVisibilityLibrary( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + safetyLabelMapSourceOption: Option[SafetyLabelMapSource]) { + + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val svlRequestCounter = libraryStatsReceiver.counter("svl_requests") + val 
vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = libraryStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + val visibilityDeciderGates = VisibilityDeciderGates(decider) + val verdictLogger = SearchVisibilityLibrary.createVerdictLogger( + visibilityDeciderGates.enableVerdictLoggerSVL, + decider, + libraryStatsReceiver) + + val tweetLabels = safetyLabelMapSourceOption match { + case Some(safetyLabelMapSource) => + new StratoTweetLabelMaps(safetyLabelMapSource) + case None => + new StratoTweetLabelMaps( + SafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + } + + val mediaLabelMaps = new StratoMediaLabelMaps( + MediaSafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + + val tweetFeatures = new TweetFeatures(tweetLabels, libraryStatsReceiver) + val searchContextFeatures = new SearchContextFeatures(libraryStatsReceiver) + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val misinfoPolicySource = + MisinformationPolicySource.fromStrato(stratoClient, stratoClientStatsReceiver) + val misinfoPolicyFeatures = + new MisinformationPolicyFeatures(misinfoPolicySource, stratoClientStatsReceiver) + val exclusiveTweetFeatures = + new ExclusiveTweetFeatures(userRelationshipSource, libraryStatsReceiver) + val mediaFeatures = new MediaFeatures(mediaLabelMaps, libraryStatsReceiver) + val trustedFriendsTweetFeatures = new TrustedFriendsFeatures( + trustedFriendsSource = TrustedFriendsSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + val editTweetFeatures = new EditTweetFeatures(libraryStatsReceiver) + + def batchProcessSearchVisibilityRequest( + batchSvRequest: BatchSearchVisibilityRequest + ): Stitch[BatchSearchVisibilityResponse] = { + val elapsed = Stopwatch.start() + svlRequestCounter.incr() + + val response: Stitch[BatchSearchVisibilityResponse] = + batchSvRequest.tweetContexts.groupBy(tweetContext => tweetContext.safetyLevel) map { + case (safetyLevel: SafetyLevel, tweetContexts: Seq[TweetContext]) => + val (contentsToBeEvaluated, contentVisResultTypes) = + extractContentsToBeEvaluated(tweetContexts, batchSvRequest.viewerContext) + + getVisibilityResult( + contentsToBeEvaluated, + safetyLevel, + batchSvRequest.viewerContext, + batchSvRequest.searchVFRequestContext) + .map { contentVisResults: Seq[Try[VisibilityResult]] => + (contentVisResultTypes zip contentVisResults) + .map(handleVisibilityResultByTweetType) + .groupBy { + case (requestTweetId: RequestTweetId, (_, _)) => requestTweetId + }.mapValues(combineVisibilityResult) + }.onSuccess(res => + res.values.flatten.foreach( + SearchVisibilityLibrary.scribeVisibilityVerdict( + _, + visibilityDeciderGates.enableVerdictScribingSVL, + verdictLogger, + batchSvRequest.viewerContext.userId, + safetyLevel))) + } reduceLeft { (left, right) => + Stitch.joinMap(left, right)((visResultsA, visResultsB) => visResultsA ++ visResultsB) + } map { visResults => + val (succeed, failed) = visResults.partition { case (_, visResult) => visResult.nonEmpty } + val failedTweetIds: Seq[Long] = failed.keys.toSeq + BatchSearchVisibilityResponse( + visibilityResults = succeed.mapValues(visResult => visResult.get), + failedTweetIds = failedTweetIds + ) + } + + val runStitchStartMs = 
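/* the same instant closes the stitch-build phase and starts the stitch-run phase */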
elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + response + .onSuccess(_ => { + val overallMs = elapsed().inMilliseconds + vfLatencyOverallStat.add(overallMs) + val stitchRunMs = elapsed().inMilliseconds - runStitchStartMs + vfLatencyStitchRunStat.add(stitchRunMs) + }) + } + + private def extractContentsToBeEvaluated( + tweetContexts: Seq[TweetContext], + viewerContext: ViewerContext + ): ( + Seq[(TweetContext, EvaluateTweetType, EvaluateTweet, ContentId)], + Seq[ + (RequestTweetId, EvaluateTweetType) + ] + ) = { + val contentsToBeEvaluated: Seq[ + (TweetContext, EvaluateTweetType, EvaluateTweet, ContentId) + ] = tweetContexts.map(tc => + ( + tc, + REQUEST, + tc.tweet, + getContentId( + viewerId = viewerContext.userId, + authorId = tc.tweet.coreData.get.userId, + tweet = tc.tweet))) ++ + tweetContexts + .filter(tc => tc.quotedTweet.nonEmpty).map(tc => + ( + tc, + QUOTED, + tc.quotedTweet.get, + getContentId( + viewerId = viewerContext.userId, + authorId = tc.quotedTweet.get.coreData.get.userId, + tweet = tc.quotedTweet.get))) ++ + tweetContexts + .filter(tc => tc.retweetSourceTweet.nonEmpty).map(tc => + ( + tc, + SOURCE, + tc.retweetSourceTweet.get, + getContentId( + viewerId = viewerContext.userId, + authorId = tc.retweetSourceTweet.get.coreData.get.userId, + tweet = tc.retweetSourceTweet.get))) + + val contentVisResultTypes: Seq[(RequestTweetId, EvaluateTweetType)] = { + contentsToBeEvaluated.map { + case (tc: TweetContext, tweetType: EvaluateTweetType, _, _) => + (tc.tweet.id, tweetType) + } + } + + (contentsToBeEvaluated, contentVisResultTypes) + } + + private def combineVisibilityResult( + visResults: Seq[(RequestTweetId, (EvaluateTweetType, Try[VisibilityResult]))] + ): Option[CombinedVisibilityResult] = { + visResults.sortBy(_._2._1)(ValueOrdering) match { + case Seq( + (_, (REQUEST, Return(requestTweetVisResult))), + (_, (QUOTED, Return(quotedTweetVisResult))), + (_, (SOURCE, Return(sourceTweetVisResult)))) => + requestTweetVisResult.verdict match { + case Allow => + Some(CombinedVisibilityResult(sourceTweetVisResult, Some(quotedTweetVisResult))) + case _ => + Some(CombinedVisibilityResult(requestTweetVisResult, Some(quotedTweetVisResult))) + } + case Seq( + (_, (REQUEST, Return(requestTweetVisResult))), + (_, (QUOTED, Return(quotedTweetVisResult)))) => + Some(CombinedVisibilityResult(requestTweetVisResult, Some(quotedTweetVisResult))) + case Seq( + (_, (REQUEST, Return(requestTweetVisResult))), + (_, (SOURCE, Return(sourceTweetVisResult)))) => + requestTweetVisResult.verdict match { + case Allow => + Some(CombinedVisibilityResult(sourceTweetVisResult, None)) + case _ => + Some(CombinedVisibilityResult(requestTweetVisResult, None)) + } + + case Seq((_, (REQUEST, Return(requestTweetVisResult)))) => + Some(CombinedVisibilityResult(requestTweetVisResult, None)) + case _ => None + } + } + + private def getVisibilityResult( + contents: Seq[(TweetContext, EvaluateTweetType, EvaluateTweet, ContentId)], + safetyLevel: SafetyLevel, + viewerContext: ViewerContext, + svRequestContext: SearchVFRequestContext + ): Stitch[Seq[Try[VisibilityResult]]] = { + + val contentContext: Map[ContentId, (TweetContext, EvaluateTweetType, EvaluateTweet)] = + contents.map { + case ( + tweetContext: TweetContext, + tweetType: EvaluateTweetType, + tweet: EvaluateTweet, + contentId: ContentId) => + contentId -> ((tweetContext, tweetType, tweet)) + }.toMap + + val featureMapProvider: (ContentId, SafetyLevel) => FeatureMap = { + case 
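/* The rule engine calls back into this provider once per content id; the SafetyLevel argument is ignored here because every content in this partition shares the level it was grouped by. */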
(contentId: ContentId, _) => + val (tweetContext, tweetType, tweet) = contentContext(contentId) + buildFeatureMap( + evaluatedTweet = tweet, + tweetType = tweetType, + tweetContext = tweetContext, + viewerContext = viewerContext, + svRequestContext = svRequestContext + ) + } + + visibilityLibrary.runRuleEngineBatch( + contentIds = contents.map { case (_, _, _, id: ContentId) => id }, + featureMapProvider = featureMapProvider, + viewerContext = viewerContext, + safetyLevel = safetyLevel + ) + } + + private def getContentId(viewerId: Option[Long], authorId: Long, tweet: Tweet): ContentId = { + if (viewerId.contains(authorId)) + TweetId(tweet.id) + else BlenderTweetId(tweet.id) + } + + private def buildFeatureMap( + evaluatedTweet: Tweet, + tweetType: EvaluateTweetType, + tweetContext: TweetContext, + viewerContext: ViewerContext, + svRequestContext: SearchVFRequestContext + ): FeatureMap = { + val authorId = evaluatedTweet.coreData.get.userId + val viewerId = viewerContext.userId + val isRetweet = + if (tweetType.equals(REQUEST)) tweetContext.retweetSourceTweet.nonEmpty else false + val isSourceTweet = tweetType.equals(SOURCE) + val isQuotedTweet = tweetType.equals(QUOTED) + val tweetMediaKeys: Seq[GenericMediaKey] = evaluatedTweet.media + .getOrElse(Seq.empty) + .flatMap(_.mediaKey.map(GenericMediaKey.apply)) + + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures + .forViewerSearchContext(svRequestContext, viewerContext), + relationshipFeatures.forAuthorId(authorId, viewerId), + tweetFeatures.forTweet(evaluatedTweet), + mediaFeatures.forMediaKeys(tweetMediaKeys), + authorFeatures.forAuthorId(authorId), + searchContextFeatures.forSearchContext(svRequestContext), + _.withConstantFeature(TweetIsRetweet, isRetweet), + misinfoPolicyFeatures.forTweet(evaluatedTweet, viewerContext), + exclusiveTweetFeatures.forTweet(evaluatedTweet, viewerContext), + trustedFriendsTweetFeatures.forTweet(evaluatedTweet, viewerId), + editTweetFeatures.forTweet(evaluatedTweet), + _.withConstantFeature(TweetIsInnerQuotedTweet, isQuotedTweet), + _.withConstantFeature(TweetIsSourceTweet, isSourceTweet), + ) + ) + } + + private def handleVisibilityResultByTweetType( + zipVisResult: ((RequestTweetId, EvaluateTweetType), Try[VisibilityResult]) + ): (RequestTweetId, (EvaluateTweetType, Try[VisibilityResult])) = { + zipVisResult match { + case ((id: RequestTweetId, REQUEST), Return(visResult)) => + (id, (REQUEST, Return(handleComposableVisibilityResult(visResult)))) + case ((id: RequestTweetId, QUOTED), Return(visResult)) => + ( + id, + ( + QUOTED, + Return( + handleInnerQuotedTweetVisibilityResult(handleComposableVisibilityResult(visResult))))) + case ((id: RequestTweetId, SOURCE), Return(visResult)) => + (id, (SOURCE, Return(handleComposableVisibilityResult(visResult)))) + case ((id: RequestTweetId, tweetType: EvaluateTweetType), result: Try[VisibilityResult]) => + (id, (tweetType, result)) + } + } + + private def handleComposableVisibilityResult(result: VisibilityResult): VisibilityResult = { + if (result.secondaryVerdicts.nonEmpty) { + result.copy(verdict = composeActions(result.verdict, result.secondaryVerdicts)) + } else { + result + } + } + + private def composeActions(primary: Action, secondary: Seq[Action]): Action = { + if (primary.isComposable && secondary.nonEmpty) { + val actions = Seq[Action] { primary } ++ secondary + val interstitialOpt = Action.getFirstInterstitial(actions: _*) + val softInterventionOpt = Action.getFirstSoftIntervention(actions: _*) + val limitedEngagementsOpt = 
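/* Take the first action of each composable family; if two or more families are present they are folded into a single TweetInterstitial below, otherwise the primary verdict stands. */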
Action.getFirstLimitedEngagements(actions: _*) + val avoidOpt = Action.getFirstAvoid(actions: _*) + + val numActions = + Seq[Option[_]](interstitialOpt, softInterventionOpt, limitedEngagementsOpt, avoidOpt) + .count(_.isDefined) + if (numActions > 1) { + TweetInterstitial( + interstitialOpt, + softInterventionOpt, + limitedEngagementsOpt, + None, + avoidOpt + ) + } else { + primary + } + } else { + primary + } + } + + private def handleInnerQuotedTweetVisibilityResult( + result: VisibilityResult + ): VisibilityResult = { + val newVerdict: Action = + result.verdict match { + case interstitial: Interstitial => Drop(interstitial.reason) + case ComposableActionsWithInterstitial(tweetInterstitial) => Drop(tweetInterstitial.reason) + case verdict => verdict + } + + result.copy(verdict = newVerdict) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/TweetContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/TweetContext.scala new file mode 100644 index 000000000..b2afda131 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/search/TweetContext.scala @@ -0,0 +1,10 @@ +package com.twitter.visibility.interfaces.search + +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.models.SafetyLevel + +case class TweetContext( + tweet: Tweet, + quotedTweet: Option[Tweet], + retweetSourceTweet: Option[Tweet] = None, + safetyLevel: SafetyLevel) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/BUILD new file mode 100644 index 000000000..ae742d9be --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/BUILD @@ -0,0 +1,37 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "decider/src/main/scala", + "servo/decider/src/main/scala", + "src/scala/com/twitter/search/blender/services/strato", + "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/client", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/stitch", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/common", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/spaces", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/utils", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + 
"visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityLibrary.scala new file mode 100644 index 000000000..8d0273095 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityLibrary.scala @@ -0,0 +1,117 @@ +package com.twitter.visibility.interfaces.spaces + +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.common.MutedKeywordFeatures +import com.twitter.visibility.builder.spaces.SpaceFeatures +import com.twitter.visibility.builder.spaces.StratoSpaceLabelMaps +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common._ +import com.twitter.visibility.common.stitch.StitchHelpers +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.SpaceId +import com.twitter.visibility.models.ContentId.SpacePlusUserId +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.utils.ShimUtils + +object SpaceVisibilityLibrary { + type Type = SpaceVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + stratoClient: StratoClient, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + enableVfFeatureHydrationSpaceShim: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfLatencyStatsReceiver = visibilityLibrary.statsReceiver.scope("vf_latency") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val spaceLabelMaps = new StratoSpaceLabelMaps( + SpaceSafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver), + libraryStatsReceiver) + val audioSpaceSource = AudioSpaceSource.fromStrato(stratoClient, stratoClientStatsReceiver) + + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val mutedKeywordFeatures = new MutedKeywordFeatures( + userSource, + userRelationshipSource, + KeywordMatcher.matcher(libraryStatsReceiver), + libraryStatsReceiver, + Gate.False + ) + val spaceFeatures = + new SpaceFeatures( + spaceLabelMaps, + authorFeatures, + relationshipFeatures, + mutedKeywordFeatures, + audioSpaceSource) + + { r: SpaceVisibilityRequest => + vfEngineCounter.incr() + + val isVfFeatureHydrationEnabled = enableVfFeatureHydrationSpaceShim() + val viewerId = r.viewerContext.userId + val authorIds: Option[Seq[Long]] = r.spaceHostAndAdminUserIds + val contentId = { + (viewerId, authorIds) match { + case (Some(viewer), Some(authors)) if authors.contains(viewer) => SpaceId(r.spaceId) + case _ => SpacePlusUserId(r.spaceId) + } + } + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + 
spaceFeatures.forSpaceAndAuthorIds(r.spaceId, viewerId, authorIds), + viewerFeatures.forViewerContext(r.viewerContext), + ) + ) + + val resp = if (isVfFeatureHydrationEnabled) { + val evaluationContext = ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = EvaluationContext( + r.safetyLevel, + visibilityLibrary.getParams(r.viewerContext, r.safetyLevel), + visibilityLibrary.statsReceiver) + ) + + val preFilteredFeatureMap = + ShimUtils.preFilterFeatureMap(featureMap, r.safetyLevel, contentId, evaluationContext) + + FeatureMap + .resolve(preFilteredFeatureMap, libraryStatsReceiver).flatMap { resolvedFeatureMap => + visibilityLibrary + .runRuleEngine( + contentId, + resolvedFeatureMap, + r.viewerContext, + r.safetyLevel + ) + } + } else { + visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + r.viewerContext, + r.safetyLevel + ) + } + + StitchHelpers.profileStitch(resp, Seq(vfLatencyStatsReceiver)) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityRequest.scala new file mode 100644 index 000000000..7e8f91269 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/spaces/SpaceVisibilityRequest.scala @@ -0,0 +1,10 @@ +package com.twitter.visibility.interfaces.spaces + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class SpaceVisibilityRequest( + spaceId: String, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext, + spaceHostAndAdminUserIds: Option[Seq[Long]]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/BUILD new file mode 100644 index 000000000..5e9b8cc14 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/BUILD @@ -0,0 +1,47 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/twitter/src/java/com/twitter/logpipeline/client:logpipeline-event-publisher-thin", + "decider/src/main/scala", + "featureswitches/featureswitches-core/src/main/scala", + "mediaservices/media-util/src/main/scala", + "servo/decider/src/main/scala", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/escherbird:media-annotation-structs-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/catalog", + "strato/src/main/scala/com/twitter/strato/client", + "twitter-context/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common/tweets", + "visibility/common/src/main/thrift/com/twitter/visibility/tweets:tweets-scala", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/common", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/media", + 
"visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/generators", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/common/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/utils", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + ], + exports = [ + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/scala/com/twitter/visibility", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/DeletedTweetVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/DeletedTweetVisibilityLibrary.scala new file mode 100644 index 000000000..4ebc81ee6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/DeletedTweetVisibilityLibrary.scala @@ -0,0 +1,59 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.decider.Decider +import com.twitter.stitch.Stitch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.features.TweetDeleteReason +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.generators.TombstoneGenerator +import com.twitter.visibility.models.ContentId.DeleteTweetId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.TweetDeleteReason.TweetDeleteReason +import com.twitter.visibility.models.ViewerContext + +object DeletedTweetVisibilityLibrary { + type Type = DeletedTweetVisibilityLibrary.Request => Stitch[VisibilityResult] + + case class Request( + tweetId: Long, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext, + tweetDeleteReason: TweetDeleteReason, + isRetweet: Boolean, + isInnerQuotedTweet: Boolean, + ) + + def apply( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + tombstoneGenerator: TombstoneGenerator, + ): Type = { + val vfEngineCounter = visibilityLibrary.statsReceiver.counter("vf_engine_requests") + + (request: Request) => { + vfEngineCounter.incr() + val contentId = DeleteTweetId(request.tweetId) + val language = request.viewerContext.requestLanguageCode.getOrElse("en") + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + _.withConstantFeature(TweetIsInnerQuotedTweet, request.isInnerQuotedTweet), + _.withConstantFeature(TweetIsRetweet, request.isRetweet), + _.withConstantFeature(TweetDeleteReason, request.tweetDeleteReason) + ) + ) + + visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + request.viewerContext, + request.safetyLevel + ) + .map(tombstoneGenerator(_, language)) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/QuotedTweetVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/QuotedTweetVisibilityLibrary.scala new file mode 100644 index 000000000..dcdc960fb --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/QuotedTweetVisibilityLibrary.scala @@ -0,0 +1,150 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.decider.Decider +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.QuotedTweetFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId.QuotedTweetRelationship +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.UserUnavailableStateEnum +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.Reason.AuthorBlocksViewer +import com.twitter.visibility.rules.Reason.DeactivatedAuthor +import com.twitter.visibility.rules.Reason.ErasedAuthor +import com.twitter.visibility.rules.Reason.OffboardedAuthor +import com.twitter.visibility.rules.Reason.ProtectedAuthor +import com.twitter.visibility.rules.Reason.SuspendedAuthor +import com.twitter.visibility.rules.Reason.ViewerBlocksAuthor +import com.twitter.visibility.rules.Reason.ViewerHardMutedAuthor +import com.twitter.visibility.rules.Reason.ViewerMutesAuthor +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.utils.ShimUtils + +case class TweetAndAuthor(tweetId: Long, authorId: Long) + +case class QuotedTweetVisibilityRequest( + quotedTweet: TweetAndAuthor, + outerTweet: TweetAndAuthor, + viewerContext: ViewerContext, + safetyLevel: SafetyLevel) + +object QuotedTweetVisibilityLibrary { + + type Type = QuotedTweetVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + decider: Decider, + userStateVisibilityLibrary: UserUnavailableStateVisibilityLibrary.Type, + enableVfFeatureHydration: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val visibilityDeciderGates = VisibilityDeciderGates(decider) + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + { + case QuotedTweetVisibilityRequest(quotedTweet, outerTweet, viewerContext, safetyLevel) => + vfEngineCounter.incr() + val contentId = QuotedTweetRelationship( + outer = outerTweet.tweetId, + inner = quotedTweet.tweetId + ) + + val innerAuthorId = quotedTweet.authorId + val outerAuthorId = outerTweet.authorId + val viewerId = viewerContext.userId + val isFeatureHydrationInShimEnabled = enableVfFeatureHydration() + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val quotedTweetFeatures = + new QuotedTweetFeatures(relationshipFeatures, libraryStatsReceiver) + + val featureMap = visibilityLibrary.featureMapBuilder( + 
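/* Features are hydrated for the inner (quoted) tweet's author; the outer tweet's author only enters via quotedTweetFeatures.forOuterAuthor. */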
Seq( + viewerFeatures.forViewerContext(viewerContext), + authorFeatures.forAuthorId(innerAuthorId), + relationshipFeatures.forAuthorId(innerAuthorId, viewerId), + quotedTweetFeatures.forOuterAuthor(outerAuthorId, innerAuthorId) + ) + ) + + val resp = if (isFeatureHydrationInShimEnabled) { + val evaluationContext = ProvidedEvaluationContext.injectRuntimeRulesIntoEvaluationContext( + evaluationContext = EvaluationContext( + SafetyLevel.QuotedTweetRules, + visibilityLibrary.getParams(viewerContext, SafetyLevel.QuotedTweetRules), + visibilityLibrary.statsReceiver) + ) + + val preFilteredFeatureMap = + ShimUtils.preFilterFeatureMap( + featureMap, + SafetyLevel.QuotedTweetRules, + contentId, + evaluationContext) + + FeatureMap.resolve(preFilteredFeatureMap, libraryStatsReceiver).flatMap { + resolvedFeatureMap => + visibilityLibrary + .runRuleEngine( + contentId, + resolvedFeatureMap, + viewerContext, + SafetyLevel.QuotedTweetRules + ) + } + } else { + visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + viewerContext, + SafetyLevel.QuotedTweetRules + ) + } + + resp.flatMap { visResult => + val userStateOpt = visResult.verdict match { + case Drop(DeactivatedAuthor, _) => Some(UserUnavailableStateEnum.Deactivated) + case Drop(OffboardedAuthor, _) => Some(UserUnavailableStateEnum.Offboarded) + case Drop(ErasedAuthor, _) => Some(UserUnavailableStateEnum.Erased) + case Drop(ProtectedAuthor, _) => Some(UserUnavailableStateEnum.Protected) + case Drop(SuspendedAuthor, _) => Some(UserUnavailableStateEnum.Suspended) + case Drop(AuthorBlocksViewer, _) => Some(UserUnavailableStateEnum.AuthorBlocksViewer) + case Drop(ViewerBlocksAuthor, _) => Some(UserUnavailableStateEnum.ViewerBlocksAuthor) + case Drop(ViewerMutesAuthor, _) => Some(UserUnavailableStateEnum.ViewerMutesAuthor) + case Drop(ViewerHardMutedAuthor, _) => Some(UserUnavailableStateEnum.ViewerMutesAuthor) + case _ => None + } + + userStateOpt + .map(userState => + userStateVisibilityLibrary( + UserUnavailableStateVisibilityRequest( + safetyLevel, + quotedTweet.tweetId, + viewerContext, + userState, + isRetweet = false, + isInnerQuotedTweet = true, + ))).getOrElse(Stitch.value(visResult)) + } + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibrary.scala new file mode 100644 index 000000000..09d717b76 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibrary.scala @@ -0,0 +1,421 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.decider.Decider +import com.twitter.featureswitches.v2.FeatureSwitches +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.mediaservices.media_util.GenericMediaKey +import com.twitter.servo.util.Gate +import com.twitter.stitch.Stitch +import com.twitter.strato.client.{Client => StratoClient} +import com.twitter.util.Stopwatch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VerdictLogger +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.common.MutedKeywordFeatures +import com.twitter.visibility.builder.media._ +import com.twitter.visibility.builder.tweets.TweetVisibilityNudgeSourceWrapper +import com.twitter.visibility.builder.tweets._ +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import 
com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.builder.users.ViewerSearchSafetyFeatures +import com.twitter.visibility.builder.users.ViewerSensitiveMediaSettingsFeatures +import com.twitter.visibility.common._ +import com.twitter.visibility.common.actions.LimitedAction +import com.twitter.visibility.common.actions.LimitedActionType +import com.twitter.visibility.common.actions.LimitedActionsPolicy +import com.twitter.visibility.rules.ComposableActions._ +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.features.TweetIsSourceTweet +import com.twitter.visibility.generators.LocalizedInterstitialGenerator +import com.twitter.visibility.generators.TombstoneGenerator +import com.twitter.visibility.interfaces.tweets.enrichments.ComplianceTweetNoticeEnrichment +import com.twitter.visibility.interfaces.tweets.enrichments.LimitedActionsPolicyEnrichment +import com.twitter.visibility.interfaces.tweets.enrichments.TweetVisibilityNudgeEnrichment +import com.twitter.visibility.logging.thriftscala.VFLibType +import com.twitter.visibility.models.ContentId.TweetId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel.toThrift +import com.twitter.visibility.rules._ + +object TweetVisibilityLibrary { + type Type = TweetVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource, + userRelationshipSource: UserRelationshipSource, + keywordMatcher: KeywordMatcher.Matcher, + invitedToConversationRepo: InvitedToConversationRepo, + decider: Decider, + stratoClient: StratoClient, + localizationSource: LocalizationSource, + tweetPerspectiveSource: TweetPerspectiveSource, + tweetMediaMetadataSource: TweetMediaMetadataSource, + tombstoneGenerator: TombstoneGenerator, + interstitialGenerator: LocalizedInterstitialGenerator, + limitedActionsFeatureSwitches: FeatureSwitches, + enableParityTest: Gate[Unit] = Gate.False + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val vfLatencyOverallStat = libraryStatsReceiver.stat("vf_latency_overall") + val vfLatencyStitchBuildStat = libraryStatsReceiver.stat("vf_latency_stitch_build") + val vfLatencyStitchRunStat = libraryStatsReceiver.stat("vf_latency_stitch_run") + val visibilityDeciderGates = VisibilityDeciderGates(decider) + val verdictLogger = + createVerdictLogger( + visibilityDeciderGates.enableVerdictLoggerTVL, + decider, + libraryStatsReceiver) + + val tweetLabelMaps = new StratoTweetLabelMaps( + SafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + + val mediaLabelMaps = new StratoMediaLabelMaps( + MediaSafetyLabelMapSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + + val tweetFeatures = new TweetFeatures(tweetLabelMaps, libraryStatsReceiver) + val tweetPerspectiveFeatures = + new TweetPerspectiveFeatures(tweetPerspectiveSource, libraryStatsReceiver) + val mediaFeatures = new MediaFeatures(mediaLabelMaps, libraryStatsReceiver) + val tweetMediaMetadataFeatures = + new TweetMediaMetadataFeatures(tweetMediaMetadataSource, libraryStatsReceiver) + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new 
ViewerFeatures(userSource, libraryStatsReceiver) + val mutedKeywordFeatures = + new MutedKeywordFeatures( + userSource, + userRelationshipSource, + keywordMatcher, + libraryStatsReceiver, + visibilityDeciderGates.enableFollowCheckInMutedKeyword + ) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val fonsrRelationshipFeatures = + new FosnrRelationshipFeatures( + tweetLabels = tweetLabelMaps, + userRelationshipSource = userRelationshipSource, + statsReceiver = libraryStatsReceiver) + val conversationControlFeatures = + new ConversationControlFeatures( + relationshipFeatures, + invitedToConversationRepo, + libraryStatsReceiver + ) + val exclusiveTweetFeatures = + new ExclusiveTweetFeatures(userRelationshipSource, libraryStatsReceiver) + + val viewerSearchSafetyFeatures = new ViewerSearchSafetyFeatures( + UserSearchSafetySource.fromStrato(stratoClient, stratoClientStatsReceiver), + libraryStatsReceiver) + + val viewerSensitiveMediaSettingsFeatures = new ViewerSensitiveMediaSettingsFeatures( + UserSensitiveMediaSettingsSource.fromStrato(stratoClient, stratoClientStatsReceiver), + libraryStatsReceiver) + + val toxicReplyFilterFeature = new ToxicReplyFilterFeature(statsReceiver = libraryStatsReceiver) + + val misinfoPolicySource = + MisinformationPolicySource.fromStrato(stratoClient, stratoClientStatsReceiver) + val misinfoPolicyFeatures = + new MisinformationPolicyFeatures(misinfoPolicySource, stratoClientStatsReceiver) + + val communityTweetFeatures = new CommunityTweetFeaturesV2( + communitiesSource = CommunitiesSource.fromStrato( + stratoClient, + stratoClientStatsReceiver + ) + ) + + val trustedFriendsTweetFeatures = new TrustedFriendsFeatures( + trustedFriendsSource = + TrustedFriendsSource.fromStrato(stratoClient, stratoClientStatsReceiver)) + + val editTweetFeatures = new EditTweetFeatures(libraryStatsReceiver) + + val parityTest = new TweetVisibilityLibraryParityTest(libraryStatsReceiver, stratoClient) + + val localizedNudgeSource = + LocalizedNudgeSource.fromLocalizationSource(localizationSource) + val tweetVisibilityNudgeFeatures = + new TweetVisibilityNudgeSourceWrapper(localizedNudgeSource) + + val localizedLimitedActionsSource = + LocalizedLimitedActionsSource.fromLocalizationSource(localizationSource) + + { r: TweetVisibilityRequest => + val elapsed = Stopwatch.start() + var runStitchStartMs = 0L + vfEngineCounter.incr() + + val contentId = TweetId(r.tweet.id) + val viewerId = r.viewerContext.userId + val authorId = r.tweet.coreData.get.userId + val tweetGenericMediaKeys = r.tweet.mediaRefs + .getOrElse(Seq.empty) + .flatMap { mediaRef => + GenericMediaKey.fromStringKey(mediaRef.genericMediaKey) + } + + val tpf = + tweetPerspectiveFeatures.forTweet( + r.tweet, + viewerId, + visibilityDeciderGates.enableFetchTweetReportedPerspective()) + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + mutedKeywordFeatures.forTweet(r.tweet, viewerId, authorId), + viewerFeatures.forViewerContext(r.viewerContext), + viewerSearchSafetyFeatures.forViewerId(viewerId), + viewerSensitiveMediaSettingsFeatures.forViewerId(viewerId), + relationshipFeatures.forAuthorId(authorId, viewerId), + fonsrRelationshipFeatures + .forTweetAndAuthorId(tweet = r.tweet, authorId = authorId, viewerId = viewerId), + tweetFeatures.forTweet(r.tweet), + tpf, + mediaFeatures.forMediaKeys(tweetGenericMediaKeys), + authorFeatures.forAuthorId(authorId), + conversationControlFeatures.forTweet(r.tweet, viewerId), + _.withConstantFeature(TweetIsInnerQuotedTweet, 
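/* position flags (inner quote / retweet / source) are injected as constant features so rules can vary by where the tweet is rendered */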
r.isInnerQuotedTweet), + _.withConstantFeature(TweetIsRetweet, r.isRetweet), + _.withConstantFeature(TweetIsSourceTweet, r.isSourceTweet), + misinfoPolicyFeatures.forTweet(r.tweet, r.viewerContext), + exclusiveTweetFeatures.forTweet(r.tweet, r.viewerContext), + communityTweetFeatures.forTweet(r.tweet, r.viewerContext), + tweetMediaMetadataFeatures + .forTweet( + r.tweet, + tweetGenericMediaKeys, + visibilityDeciderGates.enableFetchTweetMediaMetadata()), + trustedFriendsTweetFeatures.forTweet(r.tweet, viewerId), + editTweetFeatures.forTweet(r.tweet), + toxicReplyFilterFeature.forTweet(r.tweet, viewerId), + ) + ) + + val languageCode = r.viewerContext.requestLanguageCode.getOrElse("en") + val countryCode = r.viewerContext.requestCountryCode + + val response = visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + r.viewerContext, + r.safetyLevel + ) + .map( + TweetVisibilityNudgeEnrichment( + _, + tweetVisibilityNudgeFeatures, + languageCode, + countryCode)) + .map(verdict => { + if (visibilityDeciderGates.enableBackendLimitedActions()) { + LimitedActionsPolicyEnrichment( + verdict, + localizedLimitedActionsSource, + languageCode, + countryCode, + limitedActionsFeatureSwitches, + libraryStatsReceiver) + } else { + verdict + } + }) + .map( + handleComposableVisibilityResult( + _, + visibilityDeciderGates.enableMediaInterstitialComposition(), + visibilityDeciderGates.enableBackendLimitedActions())) + .map(handleInnerQuotedTweetVisibilityResult(_, r.isInnerQuotedTweet)) + .map(tombstoneGenerator(_, languageCode)) + .map(interstitialGenerator(_, languageCode)) + .map(ComplianceTweetNoticeEnrichment(_, libraryStatsReceiver)) + .onSuccess(_ => { + val overallStatMs = elapsed().inMilliseconds + vfLatencyOverallStat.add(overallStatMs) + val runStitchEndMs = elapsed().inMilliseconds + vfLatencyStitchRunStat.add(runStitchEndMs - runStitchStartMs) + }) + .onSuccess( + scribeVisibilityVerdict( + _, + visibilityDeciderGates.enableVerdictScribingTVL, + verdictLogger, + r.viewerContext.userId, + r.safetyLevel)) + + runStitchStartMs = elapsed().inMilliseconds + val buildStitchStatMs = elapsed().inMilliseconds + vfLatencyStitchBuildStat.add(buildStitchStatMs) + + if (enableParityTest()) { + response.applyEffect { resp => + Stitch.async(parityTest.runParityTest(r, resp)) + } + } else { + response + } + } + } + + def handleComposableVisibilityResult( + result: VisibilityResult, + enableMediaInterstitialComposition: Boolean, + enableBackendLimitedActions: Boolean + ): VisibilityResult = { + if (result.secondaryVerdicts.nonEmpty || enableBackendLimitedActions) { + result.copy(verdict = composeActions( + result.verdict, + result.secondaryVerdicts, + enableMediaInterstitialComposition, + enableBackendLimitedActions)) + } else { + result + } + } + + def handleInnerQuotedTweetVisibilityResult( + result: VisibilityResult, + isInnerQuotedTweet: Boolean + ): VisibilityResult = { + val newVerdict: Action = + result.verdict match { + case Interstitial(Reason.Nsfw | Reason.NsfwMedia, _, _) if isInnerQuotedTweet => + Drop(Reason.Nsfw) + case ComposableActionsWithInterstitial(tweetInterstitial) + if isInnerQuotedTweet && (tweetInterstitial.reason == Reason.Nsfw || tweetInterstitial.reason == Reason.NsfwMedia) => + Drop(Reason.Nsfw) + case verdict => verdict + } + + result.copy(verdict = newVerdict) + } + + def hasTweetRules(safetyLevel: SafetyLevel): Boolean = RuleBase.hasTweetRules(safetyLevel) + + def composeActions( + primary: Action, + secondary: Seq[Action], + enableMediaInterstitialComposition: Boolean, + 
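/* when set, limited-engagement policies from all matching actions are merged rather than taking only the first */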
enableBackendLimitedActions: Boolean + ): Action = { + if (primary.isComposable && (secondary.nonEmpty || enableBackendLimitedActions)) { + val actions = Seq[Action] { primary } ++ secondary + val interstitialOpt = Action.getFirstInterstitial(actions: _*) + val softInterventionOpt = Action.getFirstSoftIntervention(actions: _*) + val downrankOpt = Action.getFirstDownrankHomeTimeline(actions: _*) + val avoidOpt = Action.getFirstAvoid(actions: _*) + val tweetVisibilityNudgeOpt = Action.getFirstTweetVisibilityNudge(actions: _*) + + val mediaInterstitialOpt = { + val firstMediaInterstitialOpt = Action.getFirstMediaInterstitial(actions: _*) + if (enableMediaInterstitialComposition && interstitialOpt != firstMediaInterstitialOpt) { + firstMediaInterstitialOpt + } else { + None + } + } + + val limitedEngagementsOpt = enableBackendLimitedActions match { + case true => buildCompositeLimitedEngagements(Action.getAllLimitedEngagements(actions: _*)) + case false => Action.getFirstLimitedEngagements(actions: _*) + } + + val abusiveQualityOpt = { + if (actions.contains(ConversationSectionAbusiveQuality)) { + Some(ConversationSectionAbusiveQuality) + } else { + None + } + } + + val numActions = + Seq[Option[_]]( + interstitialOpt, + softInterventionOpt, + limitedEngagementsOpt, + downrankOpt, + avoidOpt, + mediaInterstitialOpt, + tweetVisibilityNudgeOpt, + abusiveQualityOpt) + .count(_.isDefined) + if (numActions > 1) { + TweetInterstitial( + interstitialOpt, + softInterventionOpt, + limitedEngagementsOpt, + downrankOpt, + avoidOpt, + mediaInterstitialOpt, + tweetVisibilityNudgeOpt, + abusiveQualityOpt + ) + } else { + if (enableBackendLimitedActions) { + limitedEngagementsOpt.getOrElse(primary) + } else { + primary + } + } + } else { + primary + } + } + + def scribeVisibilityVerdict( + result: VisibilityResult, + enableVerdictScribing: Gate[Unit], + verdictLogger: VerdictLogger, + viewerId: Option[Long], + safetyLevel: SafetyLevel + ): Unit = if (enableVerdictScribing()) { + verdictLogger.scribeVerdict( + visibilityResult = result, + viewerId = viewerId, + safetyLevel = toThrift(safetyLevel), + vfLibType = VFLibType.TweetVisibilityLibrary) + } + + def buildCompositeLimitedEngagements( + limitedEngagements: Seq[IsLimitedEngagements] + ): Option[LimitedEngagements] = { + limitedEngagements.headOption.flatMap { limitedEngagement => + val distinctLimitedActions = limitedEngagements + .collect({ case IsLimitedEngagements(Some(policy), _) => policy.limitedActions }) + .flatten + .foldRight(Map.empty[LimitedActionType, LimitedAction])({ (limitedAction, acc) => + acc + ((limitedAction.limitedActionType, limitedAction)) + }) + .values + .toSeq + + if (distinctLimitedActions.nonEmpty) { + val limitedActionsPolicy = Some(LimitedActionsPolicy(distinctLimitedActions)) + Some(LimitedEngagements(limitedEngagement.getLimitedEngagementReason, limitedActionsPolicy)) + } else { + None + } + } + } + + def createVerdictLogger( + enableVerdictLogger: Gate[Unit], + decider: Decider, + statsReceiver: StatsReceiver + ): VerdictLogger = { + if (enableVerdictLogger()) { + VerdictLogger(statsReceiver, decider) + } else { + VerdictLogger.Empty + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibraryParityTest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibraryParityTest.scala new file mode 100644 index 000000000..621ed9f9b --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityLibraryParityTest.scala @@ -0,0 +1,109 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.spam.rtf.{thriftscala => t} +import com.twitter.context.TwitterContext +import com.twitter.context.thriftscala.Viewer +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.strato.catalog.Fetch +import com.twitter.strato.client.Client +import com.twitter.strato.client.Fetcher +import com.twitter.strato.thrift.ScroogeConvImplicits._ +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.common.tweets.TweetVisibilityResultMapper +import com.twitter.visibility.models.SafetyLevel.toThrift +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.thriftscala.TweetVisibilityResult + +class TweetVisibilityLibraryParityTest(statsReceiver: StatsReceiver, stratoClient: Client) { + + private val parityTestScope = statsReceiver.scope("tweet_visibility_library_parity") + private val requests = parityTestScope.counter("requests") + private val equal = parityTestScope.counter("equal") + private val incorrect = parityTestScope.counter("incorrect") + private val empty = parityTestScope.counter("empty") + private val failures = parityTestScope.counter("failures") + + private val fetcher: Fetcher[Long, t.SafetyLevel, TweetVisibilityResult] = + stratoClient.fetcher[Long, t.SafetyLevel, TweetVisibilityResult]( + "visibility/service/TweetVisibilityResult.Tweet" + ) + + def runParityTest( + req: TweetVisibilityRequest, + resp: VisibilityResult + ): Stitch[Unit] = { + requests.incr() + + val twitterContext = TwitterContext(TwitterContextPermit) + + val viewer: Option[Viewer] = { + + val remoteViewerContext = ViewerContext.fromContext + + if (remoteViewerContext != req.viewerContext) { + val updatedRemoteViewerContext = remoteViewerContext.copy( + userId = req.viewerContext.userId + ) + + if (updatedRemoteViewerContext == req.viewerContext) { + twitterContext() match { + case None => + Some(Viewer(userId = req.viewerContext.userId)) + case Some(v) => + Some(v.copy(userId = req.viewerContext.userId)) + } + } else { + None + } + } else { + None + } + } + + val tweetypieContext = TweetypieContext( + isQuotedTweet = req.isInnerQuotedTweet, + isRetweet = req.isRetweet, + hydrateConversationControl = req.hydrateConversationControl + ) + + val parityCheck: Stitch[Fetch.Result[TweetVisibilityResult]] = { + Stitch.callFuture { + TweetypieContext.let(tweetypieContext) { + viewer match { + case Some(viewer) => + twitterContext.let(viewer) { + Stitch.run(fetcher.fetch(req.tweet.id, toThrift(req.safetyLevel))) + } + case None => + Stitch.run(fetcher.fetch(req.tweet.id, toThrift(req.safetyLevel))) + } + } + } + } + + parityCheck + .flatMap { parityResponse => + val tvr = TweetVisibilityResultMapper.fromAction(resp.verdict.toActionThrift()) + + parityResponse.v match { + case Some(ptvr) => + if (tvr == ptvr) { + equal.incr() + } else { + incorrect.incr() + } + + case None => + empty.incr() + } + + Stitch.Done + }.rescue { + case t: Throwable => + failures.incr() + Stitch.Done + + }.unit + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityRequest.scala new file mode 100644 index 000000000..d10b28082 --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetVisibilityRequest.scala @@ -0,0 +1,14 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.tweetypie.thriftscala.Tweet +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class TweetVisibilityRequest( + tweet: Tweet, + safetyLevel: SafetyLevel, + viewerContext: ViewerContext, + isInnerQuotedTweet: Boolean, + isRetweet: Boolean, + hydrateConversationControl: Boolean = false, + isSourceTweet: Boolean = false) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetypieContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetypieContext.scala new file mode 100644 index 000000000..e7178bb20 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/TweetypieContext.scala @@ -0,0 +1,59 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.finagle.context.Contexts +import com.twitter.io.Buf +import com.twitter.io.BufByteWriter +import com.twitter.io.ByteReader +import com.twitter.util.Future +import com.twitter.util.Return +import com.twitter.util.Throw +import com.twitter.util.Try + +case class TweetypieContext( + isQuotedTweet: Boolean, + isRetweet: Boolean, + hydrateConversationControl: Boolean) + +object TweetypieContext { + + def let[U](value: TweetypieContext)(f: => Future[U]): Future[U] = + Contexts.broadcast.let(TweetypieContextKey, value)(f) + + def get(): Option[TweetypieContext] = + Contexts.broadcast.get(TweetypieContextKey) +} + +object TweetypieContextKey + extends Contexts.broadcast.Key[TweetypieContext]( + "com.twitter.visibility.interfaces.tweets.TweetypieContext" + ) { + + override def marshal(value: TweetypieContext): Buf = { + val bw = BufByteWriter.fixed(1) + val byte = + ((if (value.isQuotedTweet) 1 else 0) << 0) | + ((if (value.isRetweet) 1 else 0) << 1) | + ((if (value.hydrateConversationControl) 1 else 0) << 2) + bw.writeByte(byte) + bw.owned() + } + + override def tryUnmarshal(buf: Buf): Try[TweetypieContext] = { + if (buf.length != 1) { + Throw( + new IllegalArgumentException( + s"Could not extract Boolean from Buf. 
Length ${buf.length} but required 1" + ) + ) + } else { + val byte: Byte = ByteReader(buf).readByte() + Return( + TweetypieContext( + isQuotedTweet = ((byte & 1) == 1), + isRetweet = ((byte & 2) == 2), + hydrateConversationControl = ((byte & 4) == 4) + ) + ) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityLibrary.scala new file mode 100644 index 000000000..959c76812 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityLibrary.scala @@ -0,0 +1,138 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.decider.Decider +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.stitch.Stitch +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.users.UserUnavailableFeatures +import com.twitter.visibility.common.actions.converter.scala.DropReasonConverter +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.features.TweetIsInnerQuotedTweet +import com.twitter.visibility.features.TweetIsRetweet +import com.twitter.visibility.generators.LocalizedInterstitialGenerator +import com.twitter.visibility.generators.TombstoneGenerator +import com.twitter.visibility.models.ContentId.UserUnavailableState +import com.twitter.visibility.models.UserUnavailableStateEnum +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.Interstitial +import com.twitter.visibility.rules.Reason +import com.twitter.visibility.rules.Tombstone +import com.twitter.visibility.thriftscala.UserVisibilityResult + +object UserUnavailableStateVisibilityLibrary { + type Type = UserUnavailableStateVisibilityRequest => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + decider: Decider, + tombstoneGenerator: TombstoneGenerator, + interstitialGenerator: LocalizedInterstitialGenerator + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver.scope("user_unavailable_vis_library") + val defaultDropScope = visibilityLibrary.statsReceiver.scope("default_drop") + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + + val userUnavailableFeatures = UserUnavailableFeatures(libraryStatsReceiver) + val visibilityDeciderGates = VisibilityDeciderGates(decider) + + { r: UserUnavailableStateVisibilityRequest => + vfEngineCounter.incr() + val contentId = UserUnavailableState(r.tweetId) + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + _.withConstantFeature(TweetIsInnerQuotedTweet, r.isInnerQuotedTweet), + _.withConstantFeature(TweetIsRetweet, r.isRetweet), + userUnavailableFeatures.forState(r.userUnavailableState) + ) + ) + + val language = r.viewerContext.requestLanguageCode.getOrElse("en") + + val reason = visibilityLibrary + .runRuleEngine( + contentId, + featureMap, + r.viewerContext, + r.safetyLevel + ).map(defaultToDrop(r.userUnavailableState, defaultDropScope)) + .map(tombstoneGenerator(_, language)) + .map(visibilityResult => { + if (visibilityDeciderGates.enableLocalizedInterstitialInUserStateLibrary()) { + interstitialGenerator(visibilityResult, language) + } else { + visibilityResult + } + }) + + reason + } + } + + def defaultToDrop( + userUnavailableState: UserUnavailableStateEnum, + defaultDropScope: 
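/* counts, per unavailable-user state, how often the verdict fell back to the default Drop */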
StatsReceiver + )( + result: VisibilityResult + ): VisibilityResult = + result.verdict match { + case _: Drop | _: Tombstone => result + + case _: Interstitial => result + case _ => + result.copy(verdict = + Drop(userUnavailableStateToDropReason(userUnavailableState, defaultDropScope))) + } + + private[this] def userUnavailableStateToDropReason( + userUnavailableState: UserUnavailableStateEnum, + stats: StatsReceiver + ): Reason = + userUnavailableState match { + case UserUnavailableStateEnum.Erased => + stats.counter("erased").incr() + Reason.ErasedAuthor + case UserUnavailableStateEnum.Protected => + stats.counter("protected").incr() + Reason.ProtectedAuthor + case UserUnavailableStateEnum.Offboarded => + stats.counter("offboarded").incr() + Reason.OffboardedAuthor + case UserUnavailableStateEnum.AuthorBlocksViewer => + stats.counter("author_blocks_viewer").incr() + Reason.AuthorBlocksViewer + case UserUnavailableStateEnum.Suspended => + stats.counter("suspended_author").incr() + Reason.SuspendedAuthor + case UserUnavailableStateEnum.Deactivated => + stats.counter("deactivated_author").incr() + Reason.DeactivatedAuthor + case UserUnavailableStateEnum.Filtered(result) => + stats.counter("filtered").incr() + userVisibilityResultToDropReason(result, stats.scope("filtered")) + case UserUnavailableStateEnum.Unavailable => + stats.counter("unspecified").incr() + Reason.Unspecified + case _ => + stats.counter("unknown").incr() + stats.scope("unknown").counter(userUnavailableState.name).incr() + Reason.Unspecified + } + + private[this] def userVisibilityResultToDropReason( + result: UserVisibilityResult, + stats: StatsReceiver + ): Reason = + result.action + .flatMap(DropReasonConverter.fromAction) + .map { dropReason => + val reason = Reason.fromDropReason(dropReason) + stats.counter(reason.name).incr() + reason + }.getOrElse { + stats.counter("empty").incr() + Reason.Unspecified + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityRequest.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityRequest.scala new file mode 100644 index 000000000..b419859f6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/UserUnavailableStateVisibilityRequest.scala @@ -0,0 +1,14 @@ +package com.twitter.visibility.interfaces.tweets + +import com.twitter.visibility.models.UserUnavailableStateEnum +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext + +case class UserUnavailableStateVisibilityRequest( + safetyLevel: SafetyLevel, + tweetId: Long, + viewerContext: ViewerContext, + userUnavailableState: UserUnavailableStateEnum, + isRetweet: Boolean, + isInnerQuotedTweet: Boolean, +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/BUILD new file mode 100644 index 000000000..fcf16cd5a --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/BUILD @@ -0,0 +1,25 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "featureswitches/featureswitches-core/src/main/scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "stitch/stitch-core", + "visibility/common/src/main/scala/com/twitter/visibility/common", + 
"visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/tweets", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/results/src/main/scala/com/twitter/visibility/results/richtext", + ], + exports = [ + "featureswitches/featureswitches-core/src/main/scala", + "visibility/lib/src/main/scala/com/twitter/visibility/builder", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/ComplianceTweetNoticeEnrichment.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/ComplianceTweetNoticeEnrichment.scala new file mode 100644 index 000000000..a38459068 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/ComplianceTweetNoticeEnrichment.scala @@ -0,0 +1,55 @@ +package com.twitter.visibility.interfaces.tweets.enrichments + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.results.richtext.PublicInterestReasonToPlainText +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.ComplianceTweetNoticePreEnrichment +import com.twitter.visibility.rules.PublicInterest +import com.twitter.visibility.rules.Reason + +object ComplianceTweetNoticeEnrichment { + val ComplianceTweetNoticeEnrichmentScope = "compliance_tweet_notice_enrichment" + val ComplianceTweetNoticePreEnrichmentActionScope = + "compliance_tweet_notice_pre_enrichment_action" + + val englishLanguageTag = "en" + + def apply(result: VisibilityResult, statsReceiver: StatsReceiver): VisibilityResult = { + val scopedStatsReceiver = statsReceiver.scope(ComplianceTweetNoticeEnrichmentScope) + + val enrichedVerdict = enrichVerdict(result.verdict, scopedStatsReceiver) + result.copy(verdict = enrichedVerdict) + } + + private def enrichVerdict( + verdict: Action, + statsReceiver: StatsReceiver + ): Action = { + val preEnrichmentActionScope = + statsReceiver.scope(ComplianceTweetNoticePreEnrichmentActionScope) + + verdict match { + case complianceTweetNoticePreEnrichmentVerdict: ComplianceTweetNoticePreEnrichment => + preEnrichmentActionScope.counter("").incr() + + val verdictWithDetailsAndUrl = complianceTweetNoticePreEnrichmentVerdict.reason match { + case Reason.Unspecified => + preEnrichmentActionScope.counter("reason_unspecified").incr() + complianceTweetNoticePreEnrichmentVerdict + + case reason => + preEnrichmentActionScope.counter("reason_specified").incr() + val safetyResultReason = PublicInterest.ReasonToSafetyResultReason(reason) + val (details, url) = + PublicInterestReasonToPlainText(safetyResultReason, englishLanguageTag) + complianceTweetNoticePreEnrichmentVerdict.copy( + details = Some(details), + extendedDetailsUrl = Some(url)) + } + verdictWithDetailsAndUrl.toComplianceTweetNotice() + + case _ => verdict + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/LimitedActionsPolicyEnrichment.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/LimitedActionsPolicyEnrichment.scala new file mode 100644 index 000000000..62eda043d --- 
/dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/LimitedActionsPolicyEnrichment.scala @@ -0,0 +1,173 @@ +package com.twitter.visibility.interfaces.tweets.enrichments + +import com.twitter.featureswitches.FSRecipient +import com.twitter.featureswitches.v2.FeatureSwitches +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.common.LocalizedLimitedActionsSource +import com.twitter.visibility.common.actions.converter.scala.LimitedActionTypeConverter +import com.twitter.visibility.common.actions.LimitedActionsPolicy +import com.twitter.visibility.common.actions.LimitedActionType +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.EmergencyDynamicInterstitial +import com.twitter.visibility.rules.InterstitialLimitedEngagements +import com.twitter.visibility.rules.LimitedEngagements + +case class PolicyFeatureSwitchResults( + limitedActionTypes: Option[Seq[LimitedActionType]], + copyNamespace: String, + promptType: String, + learnMoreUrl: Option[String]) + +object LimitedActionsPolicyEnrichment { + object FeatureSwitchKeys { + val LimitedActionTypes = "limited_actions_policy_limited_actions" + val CopyNamespace = "limited_actions_policy_copy_namespace" + val PromptType = "limited_actions_policy_prompt_type" + val LearnMoreUrl = "limited_actions_policy_prompt_learn_more_url" + } + + val DefaultCopyNameSpace = "Default" + val DefaultPromptType = "basic" + val LimitedActionsPolicyEnrichmentScope = "limited_actions_policy_enrichment" + val MissingLimitedActionTypesScope = "missing_limited_action_types" + val ExecutedScope = "executed" + + def apply( + result: VisibilityResult, + localizedLimitedActionSource: LocalizedLimitedActionsSource, + languageCode: String, + countryCode: Option[String], + featureSwitches: FeatureSwitches, + statsReceiver: StatsReceiver + ): VisibilityResult = { + val scopedStatsReceiver = statsReceiver.scope(LimitedActionsPolicyEnrichmentScope) + + val enrichVerdict_ = enrichVerdict( + _: Action, + localizedLimitedActionSource, + languageCode, + countryCode, + featureSwitches, + scopedStatsReceiver + ) + + result.copy( + verdict = enrichVerdict_(result.verdict), + secondaryVerdicts = result.secondaryVerdicts.map(enrichVerdict_) + ) + } + + private def enrichVerdict( + verdict: Action, + localizedLimitedActionsSource: LocalizedLimitedActionsSource, + languageCode: String, + countryCode: Option[String], + featureSwitches: FeatureSwitches, + statsReceiver: StatsReceiver + ): Action = { + val limitedActionsPolicyForReason_ = limitedActionsPolicyForReason( + _: LimitedEngagementReason, + localizedLimitedActionsSource, + languageCode, + countryCode, + featureSwitches, + statsReceiver + ) + val executedCounter = statsReceiver.scope(ExecutedScope) + + verdict match { + case le: LimitedEngagements => { + executedCounter.counter("").incr() + executedCounter.counter(le.name).incr() + le.copy( + policy = limitedActionsPolicyForReason_(le.getLimitedEngagementReason) + ) + } + case ile: InterstitialLimitedEngagements => { + executedCounter.counter("").incr() + executedCounter.counter(ile.name).incr() + ile.copy( + policy = limitedActionsPolicyForReason_( + ile.getLimitedEngagementReason + ) + ) + } + case edi: EmergencyDynamicInterstitial => { + executedCounter.counter("").incr() + executedCounter.counter(edi.name).incr() + EmergencyDynamicInterstitial( + copy = 
edi.copy, + linkOpt = edi.linkOpt, + localizedMessage = edi.localizedMessage, + policy = limitedActionsPolicyForReason_(edi.getLimitedEngagementReason) + ) + } + case _ => verdict + } + } + + private def limitedActionsPolicyForReason( + reason: LimitedEngagementReason, + localizedLimitedActionsSource: LocalizedLimitedActionsSource, + languageCode: String, + countryCode: Option[String], + featureSwitches: FeatureSwitches, + statsReceiver: StatsReceiver + ): Option[LimitedActionsPolicy] = { + val policyConfig = getPolicyFeatureSwitchResults(featureSwitches, reason) + + policyConfig.limitedActionTypes match { + case Some(limitedActionTypes) if limitedActionTypes.nonEmpty => + Some( + LimitedActionsPolicy( + limitedActionTypes.map( + localizedLimitedActionsSource.fetch( + _, + languageCode, + countryCode, + policyConfig.promptType, + policyConfig.copyNamespace, + policyConfig.learnMoreUrl + ) + ) + ) + ) + case _ => { + statsReceiver + .scope(MissingLimitedActionTypesScope).counter(reason.toLimitedActionsString).incr() + None + } + } + } + + private def getPolicyFeatureSwitchResults( + featureSwitches: FeatureSwitches, + reason: LimitedEngagementReason + ): PolicyFeatureSwitchResults = { + val recipient = FSRecipient().withCustomFields( + ("LimitedEngagementReason", reason.toLimitedActionsString) + ) + val featureSwitchesResults = featureSwitches + .matchRecipient(recipient) + + val limitedActionTypes = featureSwitchesResults + .getStringArray(FeatureSwitchKeys.LimitedActionTypes) + .map(_.map(LimitedActionTypeConverter.fromString).flatten) + + val copyNamespace = featureSwitchesResults + .getString(FeatureSwitchKeys.CopyNamespace) + .getOrElse(DefaultCopyNameSpace) + + val promptType = featureSwitchesResults + .getString(FeatureSwitchKeys.PromptType) + .getOrElse(DefaultPromptType) + + val learnMoreUrl = featureSwitchesResults + .getString(FeatureSwitchKeys.LearnMoreUrl) + .filter(_.nonEmpty) + + PolicyFeatureSwitchResults(limitedActionTypes, copyNamespace, promptType, learnMoreUrl) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/TweetVisibilityNudgeEnrichment.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/TweetVisibilityNudgeEnrichment.scala new file mode 100644 index 000000000..e8f513848 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/tweets/enrichments/TweetVisibilityNudgeEnrichment.scala @@ -0,0 +1,96 @@ +package com.twitter.visibility.interfaces.tweets.enrichments + +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.tweets.TweetVisibilityNudgeSourceWrapper +import com.twitter.visibility.common.actions.TweetVisibilityNudgeReason.SemanticCoreMisinformationLabelReason +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.LocalizedNudge +import com.twitter.visibility.rules.SoftIntervention +import com.twitter.visibility.rules.TweetVisibilityNudge + +object TweetVisibilityNudgeEnrichment { + + def apply( + result: VisibilityResult, + tweetVisibilityNudgeSourceWrapper: TweetVisibilityNudgeSourceWrapper, + languageCode: String, + countryCode: Option[String] + ): VisibilityResult = { + + val softIntervention = extractSoftIntervention(result.verdict, result.secondaryVerdicts) + + val enrichedPrimaryVerdict = enrichAction( + result.verdict, + tweetVisibilityNudgeSourceWrapper, + softIntervention, + languageCode, + countryCode) + + val enrichedSecondaryVerdicts: Seq[Action] = + 
result.secondaryVerdicts.map(sv => + enrichAction( + sv, + tweetVisibilityNudgeSourceWrapper, + softIntervention, + languageCode, + countryCode)) + + result.copy(verdict = enrichedPrimaryVerdict, secondaryVerdicts = enrichedSecondaryVerdicts) + } + + private def extractSoftIntervention( + primary: Action, + secondaries: Seq[Action] + ): Option[SoftIntervention] = { + primary match { + case si: SoftIntervention => Some(si) + case _ => + secondaries.collectFirst { + case sv: SoftIntervention => sv + } + } + } + + private def enrichAction( + action: Action, + tweetVisibilityNudgeSourceWrapper: TweetVisibilityNudgeSourceWrapper, + softIntervention: Option[SoftIntervention], + languageCode: String, + countryCode: Option[String] + ): Action = { + action match { + case TweetVisibilityNudge(reason, None) => + val localizedNudge = + tweetVisibilityNudgeSourceWrapper.getLocalizedNudge(reason, languageCode, countryCode) + if (reason == SemanticCoreMisinformationLabelReason) + TweetVisibilityNudge( + reason, + enrichLocalizedMisInfoNudge(localizedNudge, softIntervention)) + else + TweetVisibilityNudge(reason, localizedNudge) + case _ => action + } + } + + private def enrichLocalizedMisInfoNudge( + localizedNudge: Option[LocalizedNudge], + softIntervention: Option[SoftIntervention] + ): Option[LocalizedNudge] = { + softIntervention match { + case Some(si) => { + val enrichedLocalizedNudge = localizedNudge.map { ln => + val enrichedLocalizedNudgeActions = ln.localizedNudgeActions.map { na => + val enrichedPayload = na.nudgeActionPayload.map { payload => + payload.copy(ctaUrl = si.detailsUrl, heading = si.warning) + } + na.copy(nudgeActionPayload = enrichedPayload) + } + ln.copy(localizedNudgeActions = enrichedLocalizedNudgeActions) + } + enrichedLocalizedNudge + } + case None => localizedNudge + } + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/BUILD new file mode 100644 index 000000000..5bdeef86e --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/BUILD @@ -0,0 +1,21 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "decider/src/main/scala", + "servo/decider/src/main/scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "stitch/stitch-core", + "strato/src/main/scala/com/twitter/strato/client", + "visibility/common/src/main/scala/com/twitter/visibility/common", + "visibility/lib/src/main/resources/config", + "visibility/lib/src/main/scala/com/twitter/visibility", + "visibility/lib/src/main/scala/com/twitter/visibility/builder/users", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/thrift/com/twitter/visibility/context:vf-context-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/UserVisibilityLibrary.scala b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/UserVisibilityLibrary.scala new file mode 100644 index 000000000..1c8801939 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/interfaces/users/UserVisibilityLibrary.scala @@ -0,0 +1,111 @@ +package com.twitter.visibility.interfaces.users + +import com.twitter.decider.Decider +import com.twitter.gizmoduck.thriftscala.User +import 
com.twitter.servo.decider.DeciderGateBuilder +import com.twitter.stitch.Stitch +import com.twitter.strato.client.Client +import com.twitter.visibility.VisibilityLibrary +import com.twitter.visibility.builder.users.AuthorFeatures +import com.twitter.visibility.builder.users.RelationshipFeatures +import com.twitter.visibility.builder.users.ViewerAdvancedFilteringFeatures +import com.twitter.visibility.builder.users.ViewerFeatures +import com.twitter.visibility.builder.users.ViewerSearchSafetyFeatures +import com.twitter.visibility.builder.VisibilityResult +import com.twitter.visibility.builder.users.SearchFeatures +import com.twitter.visibility.common.UserRelationshipSource +import com.twitter.visibility.common.UserSearchSafetySource +import com.twitter.visibility.common.UserSource +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.context.thriftscala.UserVisibilityFilteringContext +import com.twitter.visibility.models.ContentId.UserId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.ViewerContext +import com.twitter.visibility.rules.Reason.Unspecified +import com.twitter.visibility.rules.Allow +import com.twitter.visibility.rules.Drop +import com.twitter.visibility.rules.RuleBase + +object UserVisibilityLibrary { + type Type = + (User, SafetyLevel, ViewerContext, UserVisibilityFilteringContext) => Stitch[VisibilityResult] + + def apply( + visibilityLibrary: VisibilityLibrary, + userSource: UserSource = UserSource.empty, + userRelationshipSource: UserRelationshipSource = UserRelationshipSource.empty, + stratoClient: Client, + decider: Decider + ): Type = { + val libraryStatsReceiver = visibilityLibrary.statsReceiver.scope("user_library") + val stratoClientStatsReceiver = visibilityLibrary.statsReceiver.scope("strato") + + val visibilityDeciderGates = VisibilityDeciderGates(decider) + + val vfEngineCounter = libraryStatsReceiver.counter("vf_engine_requests") + val noUserRulesCounter = libraryStatsReceiver.counter("no_user_rules_requests") + val viewerIsAuthorCounter = libraryStatsReceiver.counter("viewer_is_author_requests") + + val authorFeatures = new AuthorFeatures(userSource, libraryStatsReceiver) + val viewerFeatures = new ViewerFeatures(userSource, libraryStatsReceiver) + val relationshipFeatures = + new RelationshipFeatures(userRelationshipSource, libraryStatsReceiver) + val searchFeatures = new SearchFeatures(libraryStatsReceiver) + + val viewerSafeSearchFeatures = new ViewerSearchSafetyFeatures( + UserSearchSafetySource.fromStrato(stratoClient, stratoClientStatsReceiver), + libraryStatsReceiver) + + val deciderGateBuilder = new DeciderGateBuilder(decider) + val advancedFilteringFeatures = + new ViewerAdvancedFilteringFeatures(userSource, libraryStatsReceiver) + + (user, safetyLevel, viewerContext, userVisibilityFilteringContext) => { + val contentId = UserId(user.id) + val viewerId = viewerContext.userId + + if (!RuleBase.hasUserRules(safetyLevel)) { + noUserRulesCounter.incr() + Stitch.value(VisibilityResult(contentId = contentId, verdict = Allow)) + } else { + if (viewerId.contains(user.id)) { + viewerIsAuthorCounter.incr() + + Stitch.value(VisibilityResult(contentId = contentId, verdict = Allow)) + } else { + vfEngineCounter.incr() + + val featureMap = + visibilityLibrary.featureMapBuilder( + Seq( + viewerFeatures.forViewerContext(viewerContext), + viewerSafeSearchFeatures.forViewerId(viewerId), + relationshipFeatures.forAuthor(user, viewerId), + authorFeatures.forAuthor(user), + 
advancedFilteringFeatures.forViewerId(viewerId), + searchFeatures.forSearchContext(userVisibilityFilteringContext.searchContext) + ) + ) + + visibilityLibrary.runRuleEngine( + contentId, + featureMap, + viewerContext, + safetyLevel + ) + + } + } + } + } + + def Const(shouldDrop: Boolean): Type = + (user, _, _, _) => + Stitch.value( + VisibilityResult( + contentId = UserId(user.id), + verdict = if (shouldDrop) Drop(Unspecified) else Allow, + finished = true + ) + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/models/BUILD new file mode 100644 index 000000000..847043fc6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/BUILD @@ -0,0 +1,32 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "configapi/configapi-core", + "datatools/src/thrift/com/twitter/datatools/entityservice:entity-entities-thrift-scala", + "escherbird/src/thrift/com/twitter/escherbird/softintervention:softintervention_thrift-scala", + "featureswitches/featureswitches-core", + "finatra-internal/request/src/main/scala", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/context:twitter-context-scala", + "src/thrift/com/twitter/escherbird:tweet-annotation-scala", + "src/thrift/com/twitter/escherbird/common:common-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/spam/rtf:tweet-rtf-event-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "timelines/src/main/scala/com/twitter/timelines/model/candidate", + "timelines/src/main/scala/com/twitter/timelines/util/client_info", + "twitter-context/src/main/scala", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/safety_label_store:safety-label-store-scala", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/CommunityTweet.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/CommunityTweet.scala new file mode 100644 index 000000000..fbb4dc6ac --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/CommunityTweet.scala @@ -0,0 +1,23 @@ +package com.twitter.visibility.models + +import com.twitter.tweetypie.thriftscala.Communities +import com.twitter.tweetypie.thriftscala.Tweet + +object CommunityTweet { + def getCommunityId(communities: Communities): Option[CommunityId] = + communities.communityIds.headOption + + def getCommunityId(tweet: Tweet): Option[CommunityId] = + tweet.communities.flatMap(getCommunityId) + + def apply(tweet: Tweet): Option[CommunityTweet] = + for { + communityId <- getCommunityId(tweet) + coreData <- tweet.coreData + } yield CommunityTweet(tweet, communityId, coreData.userId) +} + +case class CommunityTweet( + tweet: Tweet, + communityId: CommunityId, + authorId: Long) diff --git
a/visibilitylib/src/main/scala/com/twitter/visibility/models/ContentId.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/ContentId.scala new file mode 100644 index 000000000..b08e2b9ca --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/ContentId.scala @@ -0,0 +1,22 @@ +package com.twitter.visibility.models + +sealed trait ContentId + +object ContentId { + case class TweetId(id: Long) extends ContentId + case class UserId(id: Long) extends ContentId + case class CardId(url: String) extends ContentId + case class QuotedTweetRelationship(outer: Long, inner: Long) extends ContentId + case class NotificationId(tweetId: Option[Long]) extends ContentId + case class DmId(id: Long) extends ContentId + case class BlenderTweetId(id: Long) extends ContentId + case class SpaceId(id: String) extends ContentId + case class SpacePlusUserId(id: String) extends ContentId + case class DmConversationId(id: String) extends ContentId + case class DmEventId(id: Long) extends ContentId + case class UserUnavailableState(tweetId: Long) extends ContentId + case class TwitterArticleId(id: Long) extends ContentId + case class DeleteTweetId(tweetId: Long) extends ContentId + case class MediaId(id: String) extends ContentId + case class CommunityId(communityId: Long) extends ContentId +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/LabelSource.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/LabelSource.scala new file mode 100644 index 000000000..0a85a024d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/LabelSource.scala @@ -0,0 +1,61 @@ +package com.twitter.visibility.models + +import com.twitter.spam.rtf.thriftscala.SafetyResultReason +import java.util.regex.Pattern + +sealed trait LabelSource { + val name: String +} + +object LabelSource { + val BotRulePrefix = "bot_id_" + val AbusePrefix = "Abuse" + val HSEPrefix = "hse" + val AgentSourceNames = Set( + SafetyResultReason.OneOff.name, + SafetyResultReason.VotingMisinformation.name, + SafetyResultReason.HackedMaterials.name, + SafetyResultReason.Scams.name, + SafetyResultReason.PlatformManipulation.name + ) + + val Regex = "\\|" + val pattern: Pattern = Pattern.compile(Regex) + + def fromString(name: String): Option[LabelSource] = Some(name) collect { + case _ if name.startsWith(BotRulePrefix) => + BotMakerRule(name.substring(BotRulePrefix.length).toLong) + case _ if name == "A" || name == "B" || name == "AB" => + SmyteSource(name) + case _ if name.startsWith(AbusePrefix) => + AbuseSource(name) + case _ if name.startsWith(HSEPrefix) => + HSESource(name) + case _ if AgentSourceNames.contains(name) => + AgentSource(name) + case _ => + StringSource(name) + } + + def parseStringSource(source: String): (String, Option[String]) = { + pattern.split(source, 2) match { + case Array(copy, "") => (copy, None) + case Array(copy, link) => (copy, Some(link)) + case Array(copy) => (copy, None) + } + } + + case class BotMakerRule(ruleId: Long) extends LabelSource { + override lazy val name: String = s"${BotRulePrefix}${ruleId}" + } + + case class SmyteSource(name: String) extends LabelSource + + case class AbuseSource(name: String) extends LabelSource + + case class AgentSource(name: String) extends LabelSource + + case class HSESource(name: String) extends LabelSource + + case class StringSource(name: String) extends LabelSource +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/MediaSafetyLabelType.scala 
b/visibilitylib/src/main/scala/com/twitter/visibility/models/MediaSafetyLabelType.scala new file mode 100644 index 000000000..18df9d511 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/MediaSafetyLabelType.scala @@ -0,0 +1,89 @@ +package com.twitter.visibility.models + +import com.twitter.visibility.safety_label_store.{thriftscala => s} +import com.twitter.visibility.util.NamingUtils + +sealed trait MediaSafetyLabelType extends SafetyLabelType { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +object MediaSafetyLabelType extends SafetyLabelType { + + val List: List[MediaSafetyLabelType] = s.MediaSafetyLabelType.list.map(fromThrift) + + val ActiveLabels: List[MediaSafetyLabelType] = List.filter { labelType => + labelType != Unknown && labelType != Deprecated + } + + private lazy val nameToValueMap: Map[String, MediaSafetyLabelType] = + List.map(l => l.name.toLowerCase -> l).toMap + def fromName(name: String): Option[MediaSafetyLabelType] = nameToValueMap.get(name.toLowerCase) + + private val UnknownThriftSafetyLabelType = + s.MediaSafetyLabelType.EnumUnknownMediaSafetyLabelType(UnknownEnumValue) + + private lazy val thriftToModelMap: Map[s.MediaSafetyLabelType, MediaSafetyLabelType] = Map( + s.MediaSafetyLabelType.NsfwHighPrecision -> NsfwHighPrecision, + s.MediaSafetyLabelType.NsfwHighRecall -> NsfwHighRecall, + s.MediaSafetyLabelType.NsfwNearPerfect -> NsfwNearPerfect, + s.MediaSafetyLabelType.NsfwCardImage -> NsfwCardImage, + s.MediaSafetyLabelType.Pdna -> Pdna, + s.MediaSafetyLabelType.PdnaNoTreatmentIfVerified -> PdnaNoTreatmentIfVerified, + s.MediaSafetyLabelType.DmcaWithheld -> DmcaWithheld, + s.MediaSafetyLabelType.LegalDemandsWithheld -> LegalDemandsWithheld, + s.MediaSafetyLabelType.LocalLawsWithheld -> LocalLawsWithheld, + s.MediaSafetyLabelType.Reserved10 -> Deprecated, + s.MediaSafetyLabelType.Reserved11 -> Deprecated, + s.MediaSafetyLabelType.Reserved12 -> Deprecated, + s.MediaSafetyLabelType.Reserved13 -> Deprecated, + s.MediaSafetyLabelType.Reserved14 -> Deprecated, + s.MediaSafetyLabelType.Reserved15 -> Deprecated, + s.MediaSafetyLabelType.Reserved16 -> Deprecated, + s.MediaSafetyLabelType.Reserved17 -> Deprecated, + s.MediaSafetyLabelType.Reserved18 -> Deprecated, + s.MediaSafetyLabelType.Reserved19 -> Deprecated, + s.MediaSafetyLabelType.Reserved20 -> Deprecated, + s.MediaSafetyLabelType.Reserved21 -> Deprecated, + s.MediaSafetyLabelType.Reserved22 -> Deprecated, + s.MediaSafetyLabelType.Reserved23 -> Deprecated, + s.MediaSafetyLabelType.Reserved24 -> Deprecated, + s.MediaSafetyLabelType.Reserved25 -> Deprecated, + s.MediaSafetyLabelType.Reserved26 -> Deprecated, + s.MediaSafetyLabelType.Reserved27 -> Deprecated, + ) + + private lazy val modelToThriftMap: Map[MediaSafetyLabelType, s.MediaSafetyLabelType] = + (for ((k, v) <- thriftToModelMap) yield (v, k)) ++ Map( + Deprecated -> s.MediaSafetyLabelType.EnumUnknownMediaSafetyLabelType(DeprecatedEnumValue), + ) + + case object NsfwHighPrecision extends MediaSafetyLabelType + case object NsfwHighRecall extends MediaSafetyLabelType + case object NsfwNearPerfect extends MediaSafetyLabelType + case object NsfwCardImage extends MediaSafetyLabelType + case object Pdna extends MediaSafetyLabelType + case object PdnaNoTreatmentIfVerified extends MediaSafetyLabelType + case object DmcaWithheld extends MediaSafetyLabelType + case object LegalDemandsWithheld extends MediaSafetyLabelType + case object LocalLawsWithheld extends MediaSafetyLabelType + + case object Deprecated extends 
MediaSafetyLabelType + case object Unknown extends MediaSafetyLabelType + + def fromThrift(safetyLabelType: s.MediaSafetyLabelType): MediaSafetyLabelType = + thriftToModelMap.get(safetyLabelType) match { + case Some(mediaSafetyLabelType) => mediaSafetyLabelType + case _ => + safetyLabelType match { + case s.MediaSafetyLabelType.EnumUnknownMediaSafetyLabelType(DeprecatedEnumValue) => + Deprecated + case _ => + Unknown + } + } + + def toThrift(safetyLabelType: MediaSafetyLabelType): s.MediaSafetyLabelType = { + modelToThriftMap + .get(safetyLabelType).getOrElse(UnknownThriftSafetyLabelType) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/MisinformationPolicy.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/MisinformationPolicy.scala new file mode 100644 index 000000000..7ce381bfc --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/MisinformationPolicy.scala @@ -0,0 +1,179 @@ +package com.twitter.visibility.models + +import com.twitter.datatools.entityservice.entities.thriftscala.FleetInterstitial +import com.twitter.datatools.entityservice.entities.{thriftscala => t} +import com.twitter.escherbird.softintervention.thriftscala.MisinformationLocalizedPolicy +import com.twitter.escherbird.thriftscala.TweetEntityAnnotation + +case class MisinformationPolicy( + semanticCoreAnnotation: SemanticCoreAnnotation, + priority: Long = MisinformationPolicy.DefaultPriority, + filteringLevel: Int = MisinformationPolicy.DefaultFilteringLevel, + publishedState: PublishedState = MisinformationPolicy.DefaultPublishedState, + engagementNudge: Boolean = MisinformationPolicy.DefaultEngagementNudge, + suppressAutoplay: Boolean = MisinformationPolicy.DefaultSuppressAutoplay, + warning: Option[String] = None, + detailsUrl: Option[String] = None, + displayType: Option[MisinfoPolicyDisplayType] = None, + applicableCountries: Seq[String] = Seq.empty, + fleetInterstitial: Option[FleetInterstitial] = None) + +object MisinformationPolicy { + private val DefaultPriority = 0 + private val DefaultFilteringLevel = 1 + private val DefaultPublishedState = PublishedState.Published + private val DefaultEngagementNudge = true + private val DefaultSuppressAutoplay = true + + def apply( + annotation: TweetEntityAnnotation, + misinformation: MisinformationLocalizedPolicy + ): MisinformationPolicy = { + MisinformationPolicy( + semanticCoreAnnotation = SemanticCoreAnnotation( + groupId = annotation.groupId, + domainId = annotation.domainId, + entityId = annotation.entityId + ), + priority = misinformation.priority.getOrElse(DefaultPriority), + filteringLevel = misinformation.filteringLevel.getOrElse(DefaultFilteringLevel), + publishedState = misinformation.publishedState match { + case Some(t.PublishedState.Draft) => PublishedState.Draft + case Some(t.PublishedState.Dogfood) => PublishedState.Dogfood + case Some(t.PublishedState.Published) => PublishedState.Published + case _ => DefaultPublishedState + }, + displayType = misinformation.displayType collect { + case t.MisinformationDisplayType.GetTheLatest => MisinfoPolicyDisplayType.GetTheLatest + case t.MisinformationDisplayType.StayInformed => MisinfoPolicyDisplayType.StayInformed + case t.MisinformationDisplayType.Misleading => MisinfoPolicyDisplayType.Misleading + case t.MisinformationDisplayType.GovernmentRequested => + MisinfoPolicyDisplayType.GovernmentRequested + }, + applicableCountries = misinformation.applicableCountries match { + case Some(countries) => countries.map(countryCode => 
countryCode.toLowerCase) + case _ => Seq.empty + }, + fleetInterstitial = misinformation.fleetInterstitial, + engagementNudge = misinformation.engagementNudge.getOrElse(DefaultEngagementNudge), + suppressAutoplay = misinformation.suppressAutoplay.getOrElse(DefaultSuppressAutoplay), + warning = misinformation.warning, + detailsUrl = misinformation.detailsUrl, + ) + } +} + +trait MisinformationPolicyTransform { + def apply(policies: Seq[MisinformationPolicy]): Seq[MisinformationPolicy] + def andThen(transform: MisinformationPolicyTransform): MisinformationPolicyTransform = + (policies: Seq[MisinformationPolicy]) => transform(this.apply(policies)) +} + +object MisinformationPolicyTransform { + + def prioritize: MisinformationPolicyTransform = + (policies: Seq[MisinformationPolicy]) => + policies + .sortBy(p => p.filteringLevel)(Ordering[Int].reverse) + .sortBy(p => p.priority)(Ordering[Long].reverse) + + def filter(filters: Seq[MisinformationPolicy => Boolean]): MisinformationPolicyTransform = + (policies: Seq[MisinformationPolicy]) => + policies.filter { policy => filters.forall { filter => filter(policy) } } + + def filterLevelAndState( + filteringLevel: Int, + publishedStates: Seq[PublishedState] + ): MisinformationPolicyTransform = + filter( + Seq( + hasFilteringLevelAtLeast(filteringLevel), + hasPublishedStates(publishedStates) + )) + + def filterLevelAndStateAndLocalized( + filteringLevel: Int, + publishedStates: Seq[PublishedState] + ): MisinformationPolicyTransform = + filter( + Seq( + hasFilteringLevelAtLeast(filteringLevel), + hasPublishedStates(publishedStates), + hasNonEmptyLocalization, + )) + + def filterState( + publishedStates: Seq[PublishedState] + ): MisinformationPolicyTransform = + filter( + Seq( + hasPublishedStates(publishedStates) + )) + + def filterStateAndLocalized( + publishedStates: Seq[PublishedState] + ): MisinformationPolicyTransform = + filter( + Seq( + hasPublishedStates(publishedStates), + hasNonEmptyLocalization, + )) + + def filterApplicableCountries( + countryCode: Option[String], + ): MisinformationPolicyTransform = + filter(Seq(policyAppliesToCountry(countryCode))) + + def filterOutGeoSpecific(): MisinformationPolicyTransform = + filter(Seq(policyIsGlobal())) + + def filterNonEngagementNudges(): MisinformationPolicyTransform = + filter( + Seq( + hasEngagementNudge, + )) + + def policyAppliesToCountry(countryCode: Option[String]): MisinformationPolicy => Boolean = + policy => + policy.applicableCountries.isEmpty || + (countryCode.nonEmpty && policy.applicableCountries.contains(countryCode.get)) + + def policyIsGlobal(): MisinformationPolicy => Boolean = + policy => policy.applicableCountries.isEmpty + + def hasFilteringLevelAtLeast(filteringLevel: Int): MisinformationPolicy => Boolean = + _.filteringLevel >= filteringLevel + + def hasPublishedStates( + publishedStates: Seq[PublishedState] + ): MisinformationPolicy => Boolean = + policy => publishedStates.contains(policy.publishedState) + + def hasNonEmptyLocalization: MisinformationPolicy => Boolean = + policy => policy.warning.nonEmpty && policy.detailsUrl.nonEmpty + + def hasEngagementNudge: MisinformationPolicy => Boolean = + policy => policy.engagementNudge + +} + +sealed trait PublishedState +object PublishedState { + case object Draft extends PublishedState + case object Dogfood extends PublishedState + case object Published extends PublishedState + + val PublicPublishedStates = Seq(PublishedState.Published) + val EmployeePublishedStates = Seq(PublishedState.Published, PublishedState.Dogfood) +} 
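A note for readers tracing MisinformationPolicyTransform above: transforms compose via andThen with the left-hand transform applied first, and because prioritize runs two stable sortBy passes, priority is the primary descending sort key while filteringLevel only breaks ties. A minimal sketch of a composed pipeline, assuming a hypothetical examplePolicies: Seq[MisinformationPolicy] is in scope and "us" stands in for the viewer's country code:

    // Illustrative composition of the transforms defined above: keep publicly
    // published, fully localized policies, keep only policies that apply in
    // the (assumed) viewer country "us", then order them by priority.
    val pipeline: MisinformationPolicyTransform =
      MisinformationPolicyTransform
        .filterStateAndLocalized(PublishedState.PublicPublishedStates)
        .andThen(MisinformationPolicyTransform.filterApplicableCountries(Some("us")))
        .andThen(MisinformationPolicyTransform.prioritize)

    val ordered: Seq[MisinformationPolicy] = pipeline(examplePolicies)

Note that filterApplicableCountries retains globally scoped policies (empty applicableCountries) as well as country matches, which is why filterOutGeoSpecific exists as a separate, stricter filter.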
+sealed trait MisinfoPolicyDisplayType +object MisinfoPolicyDisplayType { + case object GetTheLatest extends MisinfoPolicyDisplayType + case object StayInformed extends MisinfoPolicyDisplayType + case object Misleading extends MisinfoPolicyDisplayType + case object GovernmentRequested extends MisinfoPolicyDisplayType +} + +object SemanticCoreMisinformation { + val domainId: Long = 148L +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/MutedKeyword.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/MutedKeyword.scala new file mode 100644 index 000000000..2dcb6f9b2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/MutedKeyword.scala @@ -0,0 +1,3 @@ +package com.twitter.visibility.models + +case class MutedKeyword(keyword: Option[String]) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabel.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabel.scala new file mode 100644 index 000000000..06eb2492d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabel.scala @@ -0,0 +1,90 @@ +package com.twitter.visibility.models + +import com.twitter.spam.rtf.{thriftscala => s} +import com.twitter.visibility.safety_label_store.{thriftscala => store} + +case class SafetyLabel( + score: Option[Double] = None, + applicableUsers: Set[Long] = Set.empty, + source: Option[LabelSource] = None, + modelMetadata: Option[TweetModelMetadata] = None, + createdAtMsec: Option[Long] = None, + expiresAtMsec: Option[Long] = None, + labelMetadata: Option[SafetyLabelMetadata] = None, + applicableCountries: Option[Seq[String]] = None) + +object SafetyLabel { + def fromThrift(safetyLabel: s.SafetyLabel): SafetyLabel = { + SafetyLabel( + score = safetyLabel.score, + applicableUsers = safetyLabel.applicableUsers + .map { perspectivalUsers => + (perspectivalUsers map { + _.userId + }).toSet + }.getOrElse(Set.empty), + source = safetyLabel.source.flatMap(LabelSource.fromString), + modelMetadata = safetyLabel.modelMetadata.flatMap(TweetModelMetadata.fromThrift), + createdAtMsec = safetyLabel.createdAtMsec, + expiresAtMsec = safetyLabel.expiresAtMsec, + labelMetadata = safetyLabel.labelMetadata.map(SafetyLabelMetadata.fromThrift(_)), + applicableCountries = safetyLabel.applicableCountries + ) + } + + def toThrift(safetyLabel: SafetyLabel): s.SafetyLabel = { + s.SafetyLabel( + score = safetyLabel.score, + applicableUsers = if (safetyLabel.applicableUsers.nonEmpty) { + Some(safetyLabel.applicableUsers.toSeq.map { + s.PerspectivalUser(_) + }) + } else { + None + }, + source = safetyLabel.source.map(_.name), + modelMetadata = safetyLabel.modelMetadata.map(TweetModelMetadata.toThrift), + createdAtMsec = safetyLabel.createdAtMsec, + expiresAtMsec = safetyLabel.expiresAtMsec, + labelMetadata = safetyLabel.labelMetadata.map(_.toThrift), + applicableCountries = safetyLabel.applicableCountries + ) + } +} + +trait SafetyLabelWithType[EntitySafetyLabelType <: SafetyLabelType] { + val safetyLabelType: EntitySafetyLabelType + val safetyLabel: SafetyLabel +} + +case class MediaSafetyLabel( + override val safetyLabelType: MediaSafetyLabelType, + override val safetyLabel: SafetyLabel) + extends SafetyLabelWithType[MediaSafetyLabelType] { + + def fromThrift( + thriftType: store.MediaSafetyLabelType, + thriftLabel: s.SafetyLabel + ): MediaSafetyLabel = { + MediaSafetyLabel( + MediaSafetyLabelType.fromThrift(thriftType), + SafetyLabel.fromThrift(thriftLabel) + ) + } +} + +case class 
SpaceSafetyLabel( + override val safetyLabelType: SpaceSafetyLabelType, + override val safetyLabel: SafetyLabel) + extends SafetyLabelWithType[SpaceSafetyLabelType] { + + def fromThrift( + thriftType: store.SpaceSafetyLabelType, + thriftLabel: s.SafetyLabel + ): SpaceSafetyLabel = { + SpaceSafetyLabel( + SpaceSafetyLabelType.fromThrift(thriftType), + SafetyLabel.fromThrift(thriftLabel) + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelMetadata.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelMetadata.scala new file mode 100644 index 000000000..b35986aba --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelMetadata.scala @@ -0,0 +1,25 @@ +package com.twitter.visibility.models + +import com.twitter.guano.commons.thriftscala.PolicyInViolation +import com.twitter.spam.rtf.{thriftscala => s} + +case class SafetyLabelMetadata( + policyInViolation: Option[PolicyInViolation] = None, + policyUrl: Option[String] = None) { + + def toThrift: s.SafetyLabelMetadata = { + s.SafetyLabelMetadata( + policyInViolation, + policyUrl + ) + } +} + +object SafetyLabelMetadata { + def fromThrift(metadata: s.SafetyLabelMetadata): SafetyLabelMetadata = { + SafetyLabelMetadata( + metadata.policyInViolation, + metadata.policyUrl + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelType.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelType.scala new file mode 100644 index 000000000..d79bbd153 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLabelType.scala @@ -0,0 +1,7 @@ +package com.twitter.visibility.models + +trait SafetyLabelType { + val DeprecatedEnumValue: Short = -1 + val UnknownEnumValue: Short = -2 + val StratoOnlyLabelEnumValue: Short = -3 +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevel.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevel.scala new file mode 100644 index 000000000..9042b9328 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevel.scala @@ -0,0 +1,851 @@ +package com.twitter.visibility.models + +import com.twitter.spam.rtf.thriftscala.{SafetyLevel => ThriftSafetyLevel} +import com.twitter.visibility.configapi.params.SafetyLevelParam +import com.twitter.visibility.configapi.params.SafetyLevelParams._ + +sealed trait SafetyLevel { + val name: String = this.getClass.getSimpleName.dropRight(1) + def enabledParam: SafetyLevelParam +} + +object SafetyLevel { + private lazy val nameToSafetyLevelMap: Map[String, SafetyLevel] = + SafetyLevel.List.map(s => s.name.toLowerCase -> s).toMap + def fromName(name: String): Option[SafetyLevel] = { + nameToSafetyLevelMap.get(name.toLowerCase) + } + + private val DeprecatedEnumValue = -1 + + private lazy val thriftToModelMap: Map[ThriftSafetyLevel, SafetyLevel] = Map( + ThriftSafetyLevel.AccessInternalPromotedContent -> AccessInternalPromotedContent, + ThriftSafetyLevel.AdsBusinessSettings -> AdsBusinessSettings, + ThriftSafetyLevel.AdsCampaign -> AdsCampaign, + ThriftSafetyLevel.AdsManager -> AdsManager, + ThriftSafetyLevel.AdsReportingDashboard -> AdsReportingDashboard, + ThriftSafetyLevel.AllSubscribedLists -> AllSubscribedLists, + ThriftSafetyLevel.Appeals -> Appeals, + ThriftSafetyLevel.ArticleTweetTimeline -> ArticleTweetTimeline, + ThriftSafetyLevel.BaseQig -> BaseQig, + ThriftSafetyLevel.BirdwatchNoteAuthor -> BirdwatchNoteAuthor, + 
ThriftSafetyLevel.BirdwatchNoteTweetsTimeline -> BirdwatchNoteTweetsTimeline, + ThriftSafetyLevel.BirdwatchNeedsYourHelpNotifications -> BirdwatchNeedsYourHelpNotifications, + ThriftSafetyLevel.BlockMuteUsersTimeline -> BlockMuteUsersTimeline, + ThriftSafetyLevel.BrandSafety -> BrandSafety, + ThriftSafetyLevel.CardPollVoting -> CardPollVoting, + ThriftSafetyLevel.CardsService -> CardsService, + ThriftSafetyLevel.Communities -> Communities, + ThriftSafetyLevel.ContentControlToolInstall -> ContentControlToolInstall, + ThriftSafetyLevel.ConversationFocalPrehydration -> ConversationFocalPrehydration, + ThriftSafetyLevel.ConversationFocalTweet -> ConversationFocalTweet, + ThriftSafetyLevel.ConversationInjectedTweet -> ConversationInjectedTweet, + ThriftSafetyLevel.ConversationReply -> ConversationReply, + ThriftSafetyLevel.CuratedTrendsRepresentativeTweet -> CuratedTrendsRepresentativeTweet, + ThriftSafetyLevel.CurationPolicyViolations -> CurationPolicyViolations, + ThriftSafetyLevel.DevPlatformGetListTweets -> DevPlatformGetListTweets, + ThriftSafetyLevel.DesFollowingAndFollowersUserList -> DesFollowingAndFollowersUserList, + ThriftSafetyLevel.DesHomeTimeline -> DesHomeTimeline, + ThriftSafetyLevel.DesQuoteTweetTimeline -> DesQuoteTweetTimeline, + ThriftSafetyLevel.DesRealtime -> DesRealtime, + ThriftSafetyLevel.DesRealtimeSpamEnrichment -> DesRealtimeSpamEnrichment, + ThriftSafetyLevel.DesRealtimeTweetFilter -> DesRealtimeTweetFilter, + ThriftSafetyLevel.DesRetweetingUsers -> DesRetweetingUsers, + ThriftSafetyLevel.DesTweetDetail -> DesTweetDetail, + ThriftSafetyLevel.DesTweetLikingUsers -> DesTweetLikingUsers, + ThriftSafetyLevel.DesUserBookmarks -> DesUserBookmarks, + ThriftSafetyLevel.DesUserLikedTweets -> DesUserLikedTweets, + ThriftSafetyLevel.DesUserMentions -> DesUserMentions, + ThriftSafetyLevel.DesUserTweets -> DesUserTweets, + ThriftSafetyLevel.DevPlatformComplianceStream -> DevPlatformComplianceStream, + ThriftSafetyLevel.DirectMessages -> DirectMessages, + ThriftSafetyLevel.DirectMessagesConversationList -> DirectMessagesConversationList, + ThriftSafetyLevel.DirectMessagesConversationTimeline -> DirectMessagesConversationTimeline, + ThriftSafetyLevel.DirectMessagesInbox -> DirectMessagesInbox, + ThriftSafetyLevel.DirectMessagesMutedUsers -> DirectMessagesMutedUsers, + ThriftSafetyLevel.DirectMessagesPinned -> DirectMessagesPinned, + ThriftSafetyLevel.DirectMessagesSearch -> DirectMessagesSearch, + ThriftSafetyLevel.EditHistoryTimeline -> EditHistoryTimeline, + ThriftSafetyLevel.ElevatedQuoteTweetTimeline -> ElevatedQuoteTweetTimeline, + ThriftSafetyLevel.EmbeddedTweet -> EmbeddedTweet, + ThriftSafetyLevel.EmbedsPublicInterestNotice -> EmbedsPublicInterestNotice, + ThriftSafetyLevel.EmbedTweetMarkup -> EmbedTweetMarkup, + ThriftSafetyLevel.ExploreRecommendations -> ExploreRecommendations, + ThriftSafetyLevel.WritePathLimitedActionsEnforcement -> WritePathLimitedActionsEnforcement, + ThriftSafetyLevel.FilterAll -> FilterAll, + ThriftSafetyLevel.FilterAllPlaceholder -> FilterAllPlaceholder, + ThriftSafetyLevel.FilterDefault -> FilterDefault, + ThriftSafetyLevel.FilterNone -> FilterNone, + ThriftSafetyLevel.FollowedTopicsTimeline -> FollowedTopicsTimeline, + ThriftSafetyLevel.FollowerConnections -> FollowerConnections, + ThriftSafetyLevel.FollowingAndFollowersUserList -> FollowingAndFollowersUserList, + ThriftSafetyLevel.ForDevelopmentOnly -> ForDevelopmentOnly, + ThriftSafetyLevel.FriendsFollowingList -> FriendsFollowingList, + ThriftSafetyLevel.GraphqlDefault -> GraphqlDefault, + 
ThriftSafetyLevel.HumanizationNudge -> HumanizationNudge, + ThriftSafetyLevel.KitchenSinkDevelopment -> KitchenSinkDevelopment, + ThriftSafetyLevel.ListHeader -> ListHeader, + ThriftSafetyLevel.ListMemberships -> ListMemberships, + ThriftSafetyLevel.ListOwnerships -> ListOwnerships, + ThriftSafetyLevel.ListRecommendations -> ListRecommendations, + ThriftSafetyLevel.ListSearch -> ListSearch, + ThriftSafetyLevel.ListSubscriptions -> ListSubscriptions, + ThriftSafetyLevel.LivePipelineEngagementCounts -> LivePipelineEngagementCounts, + ThriftSafetyLevel.LiveVideoTimeline -> LiveVideoTimeline, + ThriftSafetyLevel.MagicRecs -> MagicRecs, + ThriftSafetyLevel.MagicRecsV2 -> MagicRecsV2, + ThriftSafetyLevel.MagicRecsAggressive -> MagicRecsAggressive, + ThriftSafetyLevel.MagicRecsAggressiveV2 -> MagicRecsAggressiveV2, + ThriftSafetyLevel.Minimal -> Minimal, + ThriftSafetyLevel.ModeratedTweetsTimeline -> ModeratedTweetsTimeline, + ThriftSafetyLevel.Moments -> Moments, + ThriftSafetyLevel.NearbyTimeline -> NearbyTimeline, + ThriftSafetyLevel.NewUserExperience -> NewUserExperience, + ThriftSafetyLevel.NotificationsIbis -> NotificationsIbis, + ThriftSafetyLevel.NotificationsPlatform -> NotificationsPlatform, + ThriftSafetyLevel.NotificationsPlatformPush -> NotificationsPlatformPush, + ThriftSafetyLevel.NotificationsQig -> NotificationsQig, + ThriftSafetyLevel.NotificationsRead -> NotificationsRead, + ThriftSafetyLevel.NotificationsTimelineDeviceFollow -> NotificationsTimelineDeviceFollow, + ThriftSafetyLevel.NotificationsWrite -> NotificationsWrite, + ThriftSafetyLevel.NotificationsWriterTweetHydrator -> NotificationsWriterTweetHydrator, + ThriftSafetyLevel.NotificationsWriterV2 -> NotificationsWriterV2, + ThriftSafetyLevel.ProfileMixerMedia -> ProfileMixerMedia, + ThriftSafetyLevel.ProfileMixerFavorites -> ProfileMixerFavorites, + ThriftSafetyLevel.QuickPromoteTweetEligibility -> QuickPromoteTweetEligibility, + ThriftSafetyLevel.QuoteTweetTimeline -> QuoteTweetTimeline, + ThriftSafetyLevel.QuotedTweetRules -> QuotedTweetRules, + ThriftSafetyLevel.Recommendations -> Recommendations, + ThriftSafetyLevel.RecosVideo -> RecosVideo, + ThriftSafetyLevel.RecosWritePath -> RecosWritePath, + ThriftSafetyLevel.RepliesGrouping -> RepliesGrouping, + ThriftSafetyLevel.ReportCenter -> ReportCenter, + ThriftSafetyLevel.ReturningUserExperience -> ReturningUserExperience, + ThriftSafetyLevel.ReturningUserExperienceFocalTweet -> ReturningUserExperienceFocalTweet, + ThriftSafetyLevel.Revenue -> Revenue, + ThriftSafetyLevel.RitoActionedTweetTimeline -> RitoActionedTweetTimeline, + ThriftSafetyLevel.SafeSearchMinimal -> SafeSearchMinimal, + ThriftSafetyLevel.SafeSearchStrict -> SafeSearchStrict, + ThriftSafetyLevel.SearchHydration -> SearchHydration, + ThriftSafetyLevel.SearchLatest -> SearchLatest, + ThriftSafetyLevel.SearchTop -> SearchTop, + ThriftSafetyLevel.SearchTopQig -> SearchTopQig, + ThriftSafetyLevel.SearchMixerSrpMinimal -> SearchMixerSrpMinimal, + ThriftSafetyLevel.SearchMixerSrpStrict -> SearchMixerSrpStrict, + ThriftSafetyLevel.SearchPeopleSrp -> SearchPeopleSrp, + ThriftSafetyLevel.SearchPeopleTypeahead -> SearchPeopleTypeahead, + ThriftSafetyLevel.SearchPhoto -> SearchPhoto, + ThriftSafetyLevel.SearchTrendTakeoverPromotedTweet -> SearchTrendTakeoverPromotedTweet, + ThriftSafetyLevel.SearchVideo -> SearchVideo, + ThriftSafetyLevel.SearchBlenderUserRules -> SearchBlenderUserRules, + ThriftSafetyLevel.SearchLatestUserRules -> SearchLatestUserRules, + ThriftSafetyLevel.ShoppingManagerSpyMode -> 
ShoppingManagerSpyMode, + ThriftSafetyLevel.SignalsReactions -> SignalsReactions, + ThriftSafetyLevel.SignalsTweetReactingUsers -> SignalsTweetReactingUsers, + ThriftSafetyLevel.SocialProof -> SocialProof, + ThriftSafetyLevel.SoftInterventionPivot -> SoftInterventionPivot, + ThriftSafetyLevel.SpaceFleetline -> SpaceFleetline, + ThriftSafetyLevel.SpaceHomeTimelineUpranking -> SpaceHomeTimelineUpranking, + ThriftSafetyLevel.SpaceJoinScreen -> SpaceJoinScreen, + ThriftSafetyLevel.SpaceNotifications -> SpaceNotifications, + ThriftSafetyLevel.Spaces -> Spaces, + ThriftSafetyLevel.SpacesParticipants -> SpacesParticipants, + ThriftSafetyLevel.SpacesSellerApplicationStatus -> SpacesSellerApplicationStatus, + ThriftSafetyLevel.SpacesSharing -> SpacesSharing, + ThriftSafetyLevel.SpaceTweetAvatarHomeTimeline -> SpaceTweetAvatarHomeTimeline, + ThriftSafetyLevel.StickersTimeline -> StickersTimeline, + ThriftSafetyLevel.StratoExtLimitedEngagements -> StratoExtLimitedEngagements, + ThriftSafetyLevel.StreamServices -> StreamServices, + ThriftSafetyLevel.SuperFollowerConnections -> SuperFollowerConnections, + ThriftSafetyLevel.SuperLike -> SuperLike, + ThriftSafetyLevel.Test -> Test, + ThriftSafetyLevel.TimelineBookmark -> TimelineBookmark, + ThriftSafetyLevel.TimelineContentControls -> TimelineContentControls, + ThriftSafetyLevel.TimelineConversations -> TimelineConversations, + ThriftSafetyLevel.TimelineConversationsDownranking -> TimelineConversationsDownranking, + ThriftSafetyLevel.TimelineConversationsDownrankingMinimal -> TimelineConversationsDownrankingMinimal, + ThriftSafetyLevel.TimelineFavorites -> TimelineFavorites, + ThriftSafetyLevel.TimelineFavoritesSelfView -> TimelineFavoritesSelfView, + ThriftSafetyLevel.TimelineFocalTweet -> TimelineFocalTweet, + ThriftSafetyLevel.TimelineFollowingActivity -> TimelineFollowingActivity, + ThriftSafetyLevel.TimelineHome -> TimelineHome, + ThriftSafetyLevel.TimelineHomeCommunities -> TimelineHomeCommunities, + ThriftSafetyLevel.TimelineHomeHydration -> TimelineHomeHydration, + ThriftSafetyLevel.TimelineHomeLatest -> TimelineHomeLatest, + ThriftSafetyLevel.TimelineHomePromotedHydration -> TimelineHomePromotedHydration, + ThriftSafetyLevel.TimelineHomeRecommendations -> TimelineHomeRecommendations, + ThriftSafetyLevel.TimelineHomeTopicFollowRecommendations -> TimelineHomeTopicFollowRecommendations, + ThriftSafetyLevel.TimelineScorer -> TimelineScorer, + ThriftSafetyLevel.TimelineInjection -> TimelineInjection, + ThriftSafetyLevel.TimelineLikedBy -> TimelineLikedBy, + ThriftSafetyLevel.TimelineLists -> TimelineLists, + ThriftSafetyLevel.TimelineMedia -> TimelineMedia, + ThriftSafetyLevel.TimelineMentions -> TimelineMentions, + ThriftSafetyLevel.TimelineModeratedTweetsHydration -> TimelineModeratedTweetsHydration, + ThriftSafetyLevel.TimelineProfile -> TimelineProfile, + ThriftSafetyLevel.TimelineProfileAll -> TimelineProfileAll, + ThriftSafetyLevel.TimelineProfileSpaces -> TimelineProfileSpaces, + ThriftSafetyLevel.TimelineProfileSuperFollows -> TimelineProfileSuperFollows, + ThriftSafetyLevel.TimelineReactiveBlending -> TimelineReactiveBlending, + ThriftSafetyLevel.TimelineRetweetedBy -> TimelineRetweetedBy, + ThriftSafetyLevel.TimelineSuperLikedBy -> TimelineSuperLikedBy, + ThriftSafetyLevel.Tombstoning -> Tombstoning, + ThriftSafetyLevel.TopicRecommendations -> TopicRecommendations, + ThriftSafetyLevel.TopicsLandingPageTopicRecommendations -> TopicsLandingPageTopicRecommendations, + ThriftSafetyLevel.TrendsRepresentativeTweet -> TrendsRepresentativeTweet, + 
ThriftSafetyLevel.TrustedFriendsUserList -> TrustedFriendsUserList, + ThriftSafetyLevel.GryphonDecksAndColumns -> GryphonDecksAndColumns, + ThriftSafetyLevel.TweetDetail -> TweetDetail, + ThriftSafetyLevel.TweetDetailNonToo -> TweetDetailNonToo, + ThriftSafetyLevel.TweetDetailWithInjectionsHydration -> TweetDetailWithInjectionsHydration, + ThriftSafetyLevel.TweetEngagers -> TweetEngagers, + ThriftSafetyLevel.TweetReplyNudge -> TweetReplyNudge, + ThriftSafetyLevel.TweetScopedTimeline -> TweetScopedTimeline, + ThriftSafetyLevel.TweetWritesApi -> TweetWritesApi, + ThriftSafetyLevel.TwitterArticleCompose -> TwitterArticleCompose, + ThriftSafetyLevel.TwitterArticleProfileTab -> TwitterArticleProfileTab, + ThriftSafetyLevel.TwitterArticleRead -> TwitterArticleRead, + ThriftSafetyLevel.UserProfileHeader -> UserProfileHeader, + ThriftSafetyLevel.UserMilestoneRecommendation -> UserMilestoneRecommendation, + ThriftSafetyLevel.UserScopedTimeline -> UserScopedTimeline, + ThriftSafetyLevel.UserSearchSrp -> UserSearchSrp, + ThriftSafetyLevel.UserSearchTypeahead -> UserSearchTypeahead, + ThriftSafetyLevel.UserSelfViewOnly -> UserSelfViewOnly, + ThriftSafetyLevel.UserSettings -> UserSettings, + ThriftSafetyLevel.VideoAds -> VideoAds, + ThriftSafetyLevel.ZipbirdConsumerArchives -> ZipbirdConsumerArchives, + ThriftSafetyLevel.TweetAward -> TweetAward, + ) + + private lazy val modelToThriftMap: Map[SafetyLevel, ThriftSafetyLevel] = + for ((k, v) <- thriftToModelMap) yield (v, k) + + case object AdsBusinessSettings extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAdsBusinessSettingsSafetyLevelParam + } + case object AdsCampaign extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAdsCampaignSafetyLevelParam + } + case object AdsManager extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAdsManagerSafetyLevelParam + } + case object AdsReportingDashboard extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAdsReportingDashboardSafetyLevelParam + } + case object AllSubscribedLists extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAllSubscribedListsSafetyLevelParam + } + case object Appeals extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableAppealsSafetyLevelParam + } + case object ArticleTweetTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableArticleTweetTimelineSafetyLevelParam + } + case object BaseQig extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableBaseQigSafetyLevelParam + } + case object BirdwatchNoteAuthor extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableBirdwatchNoteAuthorSafetyLevel + } + case object BirdwatchNoteTweetsTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableBirdwatchNoteTweetsTimelineSafetyLevel + } + case object BirdwatchNeedsYourHelpNotifications extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableBirdwatchNeedsYourHelpNotificationsSafetyLevel + } + case object BlockMuteUsersTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableBlockMuteUsersTimelineSafetyLevelParam + } + case object BrandSafety extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableBrandSafetySafetyLevelParam + } + case object CardPollVoting extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableCardPollVotingSafetyLevelParam + } + case object CardsService 
extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableCardsServiceSafetyLevelParam + } + case object Communities extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableCommunitiesSafetyLevelParam + } + case object ContentControlToolInstall extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableContentControlToolInstallSafetyLevelParam + } + case object ConversationFocalPrehydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableConversationFocalPrehydrationSafetyLevelParam + } + case object ConversationFocalTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableConversationFocalTweetSafetyLevelParam + } + case object ConversationInjectedTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableConversationInjectedTweetSafetyLevelParam + } + case object ConversationReply extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableConversationReplySafetyLevelParam + } + case object AccessInternalPromotedContent extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableAccessInternalPromotedContentSafetyLevelParam + } + case object CuratedTrendsRepresentativeTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableCuratedTrendsRepresentativeTweet + } + case object CurationPolicyViolations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableCurationPolicyViolations + } + case object DevPlatformGetListTweets extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDevPlatformGetListTweetsSafetyLevelParam + } + case object DesFollowingAndFollowersUserList extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableDESFollowingAndFollowersUserListSafetyLevelParam + } + case object DesHomeTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESHomeTimelineSafetyLevelParam + } + case object DesQuoteTweetTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESQuoteTweetTimelineSafetyLevelParam + } + case object DesRealtime extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESRealtimeSafetyLevelParam + } + case object DesRealtimeSpamEnrichment extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESRealtimeSpamEnrichmentSafetyLevelParam + } + case object DesRealtimeTweetFilter extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESRealtimeTweetFilterSafetyLevelParam + } + case object DesRetweetingUsers extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESRetweetingUsersSafetyLevelParam + } + case object DesTweetDetail extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDesTweetDetailSafetyLevelParam + } + case object DesTweetLikingUsers extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESTweetLikingUsersSafetyLevelParam + } + case object DesUserBookmarks extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESUserBookmarksSafetyLevelParam + } + case object DesUserLikedTweets extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESUserLikedTweetSafetyLevelParam + } + case object DesUserMentions extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDESUserMentionsSafetyLevelParam + } + case object DesUserTweets extends SafetyLevel { + override val enabledParam: 
SafetyLevelParam = EnableDESUserTweetsSafetyLevelParam + } + case object DevPlatformComplianceStream extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDevPlatformComplianceStreamSafetyLevelParam + } + case object DirectMessages extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDirectMessagesSafetyLevelParam + } + case object DirectMessagesConversationList extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableDirectMessagesConversationListSafetyLevelParam + } + case object DirectMessagesConversationTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableDirectMessagesConversationTimelineSafetyLevelParam + } + case object DirectMessagesInbox extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableDirectMessagesInboxSafetyLevelParam + } + case object DirectMessagesMutedUsers extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDirectMessagesMutedUsersSafetyLevelParam + } + case object DirectMessagesPinned extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDirectMessagesPinnedSafetyLevelParam + } + case object DirectMessagesSearch extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDirectMessagesSearchSafetyLevelParam + } + case object EditHistoryTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableEditHistoryTimelineSafetyLevelParam + } + case object ElevatedQuoteTweetTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableElevatedQuoteTweetTimelineSafetyLevelParam + } + case object EmbeddedTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableEmbeddedTweetSafetyLevelParam + } + case object EmbedsPublicInterestNotice extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableEmbedsPublicInterestNoticeSafetyLevelParam + } + case object EmbedTweetMarkup extends SafetyLevel { + override def enabledParam: SafetyLevelParam = EnableEmbedTweetMarkupSafetyLevelParam + } + case object WritePathLimitedActionsEnforcement extends SafetyLevel { + override def enabledParam: SafetyLevelParam = + EnableWritePathLimitedActionsEnforcementSafetyLevelParam + } + case object FilterNone extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFilterNoneSafetyLevelParam + } + case object FilterAll extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFilterAllSafetyLevelParam + } + case object FilterAllPlaceholder extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFilterDefaultSafetyLevelParam + } + case object FilterDefault extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFilterDefaultSafetyLevelParam + } + case object FollowedTopicsTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFollowedTopicsTimelineSafetyLevelParam + } + case object FollowerConnections extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableFollowerConnectionsSafetyLevelParam + } + case object FollowingAndFollowersUserList extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableFollowingAndFollowersUserListSafetyLevelParam + } + case object ForDevelopmentOnly extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableForDevelopmentOnlySafetyLevelParam + } + case object FriendsFollowingList extends SafetyLevel { + override val enabledParam: SafetyLevelParam = 
EnableFriendsFollowingListSafetyLevelParam + } + case object GraphqlDefault extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableGraphqlDefaultSafetyLevelParam + } + case object GryphonDecksAndColumns extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableGryphonDecksAndColumnsSafetyLevelParam + } + case object HumanizationNudge extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableHumanizationNudgeSafetyLevelParam + } + case object KitchenSinkDevelopment extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableKitchenSinkDevelopmentSafetyLevelParam + } + case object ListHeader extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListHeaderSafetyLevelParam + } + case object ListMemberships extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListMembershipsSafetyLevelParam + } + case object ListOwnerships extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListOwnershipsSafetyLevelParam + } + case object ListRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListRecommendationsSafetyLevelParam + } + case object ListSearch extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListSearchSafetyLevelParam + } + case object ListSubscriptions extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableListSubscriptionsSafetyLevelParam + } + case object LivePipelineEngagementCounts extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableLivePipelineEngagementCountsSafetyLevelParam + } + case object LiveVideoTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableLiveVideoTimelineSafetyLevelParam + } + case object MagicRecs extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMagicRecsSafetyLevelParam + } + case object MagicRecsAggressive extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMagicRecsAggressiveSafetyLevelParam + } + case object MagicRecsAggressiveV2 extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMagicRecsAggressiveV2SafetyLevelParam + } + case object MagicRecsV2 extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMagicRecsV2SafetyLevelParam + } + case object Minimal extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMinimalSafetyLevelParam + } + case object ModeratedTweetsTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableModeratedTweetsTimelineSafetyLevelParam + } + case object Moments extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableMomentsSafetyLevelParam + } + case object NearbyTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNearbySafetyLevelParam + } + case object NewUserExperience extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNewUserExperienceSafetyLevelParam + } + case object NotificationsIbis extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsIbisSafetyLevelParam + } + case object NotificationsPlatform extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsPlatformSafetyLevelParam + } + case object NotificationsPlatformPush extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsPlatformPushSafetyLevelParam + } + case object NotificationsQig 
extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsQigSafetyLevelParam + } + case object NotificationsRead extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsReadSafetyLevelParam + } + case object NotificationsTimelineDeviceFollow extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableNotificationsTimelineDeviceFollowSafetyLevelParam + } + case object NotificationsWrite extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsWriteSafetyLevelParam + } + case object NotificationsWriterV2 extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableNotificationsWriterV2SafetyLevelParam + } + case object NotificationsWriterTweetHydrator extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableNotificationsWriterTweetHydratorSafetyLevelParam + } + case object ProfileMixerMedia extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableProfileMixerMediaSafetyLevelParam + } + case object ProfileMixerFavorites extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableProfileMixerFavoritesSafetyLevelParam + } + case object QuickPromoteTweetEligibility extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableQuickPromoteTweetEligibilitySafetyLevelParam + } + case object QuoteTweetTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableQuoteTweetTimelineSafetyLevelParam + } + case object QuotedTweetRules extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableQuotedTweetRulesParam + } + case object Recommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRecommendationsSafetyLevelParam + } + case object RecosVideo extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRecosVideoSafetyLevelParam + } + case object RecosWritePath extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRecosWritePathSafetyLevelParam + } + case object RepliesGrouping extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRepliesGroupingSafetyLevelParam + } + case object ReportCenter extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableReportCenterSafetyLevelParam + } + case object ReturningUserExperience extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableReturningUserExperienceSafetyLevelParam + } + case object ReturningUserExperienceFocalTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableReturningUserExperienceFocalTweetSafetyLevelParam + } + case object Revenue extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRevenueSafetyLevelParam + } + case object RitoActionedTweetTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableRitoActionedTweetTimelineParam + } + case object SafeSearchMinimal extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSafeSearchMinimalSafetyLevelParam + } + case object SafeSearchStrict extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSafeSearchStrictSafetyLevelParam + } + case object SearchHydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchHydrationSafetyLevelParam + } + case object SearchLatest extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchLatestSafetyLevelParam + } + case 
object SearchMixerSrpMinimal extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchMixerSrpMinimalSafetyLevelParam + } + case object SearchMixerSrpStrict extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchMixerSrpStrictSafetyLevelParam + } + case object SearchPeopleSrp extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchPeopleSearchResultPageSafetyLevelParam + } + case object SearchPeopleTypeahead extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchPeopleTypeaheadSafetyLevelParam + } + case object SearchPhoto extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchPhotoSafetyLevelParam + } + case object ShoppingManagerSpyMode extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableShoppingManagerSpyModeSafetyLevelParam + } + case object StratoExtLimitedEngagements extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableStratoExtLimitedEngagementsSafetyLevelParam + } + case object SearchTop extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchTopSafetyLevelParam + } + case object SearchTopQig extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchTopQigSafetyLevelParam + } + case object SearchTrendTakeoverPromotedTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = SearchTrendTakeoverPromotedTweetSafetyLevelParam + } + case object SearchVideo extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchVideoSafetyLevelParam + } + case object SearchBlenderUserRules extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchBlenderUserRulesSafetyLevelParam + } + case object SearchLatestUserRules extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSearchLatestUserRulesSafetyLevelParam + } + case object SignalsReactions extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSignalsReactionsSafetyLevelParam + } + case object SignalsTweetReactingUsers extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSignalsTweetReactingUsersSafetyLevelParam + } + case object SocialProof extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSocialProofSafetyLevelParam + } + case object SoftInterventionPivot extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSoftInterventionPivotSafetyLevelParam + } + case object SpaceFleetline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpaceFleetlineSafetyLevelParam + } + case object SpaceHomeTimelineUpranking extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpaceHomeTimelineUprankingSafetyLevelParam + } + case object SpaceJoinScreen extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpaceJoinScreenSafetyLevelParam + } + case object SpaceNotifications extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpaceNotificationSafetyLevelParam + } + case object Spaces extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpacesSafetyLevelParam + } + case object SpacesParticipants extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpacesParticipantsSafetyLevelParam + } + case object SpacesSellerApplicationStatus extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + 
EnableSpacesSellerApplicationStatusSafetyLevelParam + } + case object SpacesSharing extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpacesSharingSafetyLevelParam + } + case object SpaceTweetAvatarHomeTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSpaceTweetAvatarHomeTimelineSafetyLevelParam + } + case object StickersTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableStickersTimelineSafetyLevelParam + } + case object StreamServices extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableStreamServicesSafetyLevelParam + } + case object SuperFollowerConnections extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSuperFollowerConnectionsSafetyLevelParam + } + case object SuperLike extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableSuperLikeSafetyLevelParam + } + case object Test extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTestSafetyLevelParam + } + case object TimelineConversations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineConversationsSafetyLevelParam + } + case object TimelineConversationsDownranking extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineConversationsDownrankingSafetyLevelParam + } + case object TimelineConversationsDownrankingMinimal extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineConversationsDownrankingMinimalSafetyLevelParam + } + case object TimelineFollowingActivity extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineFollowingActivitySafetyLevelParam + } + case object TimelineHome extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineHomeSafetyLevelParam + } + case object TimelineHomeCommunities extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineHomeCommunitiesSafetyLevelParam + } + case object TimelineHomeHydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineHomeHydrationSafetyLevelParam + } + case object TimelineHomePromotedHydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineHomePromotedHydrationSafetyLevelParam + } + case object TimelineHomeRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineHomeRecommendationsSafetyLevelParam + } + case object TimelineHomeTopicFollowRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineHomeTopicFollowRecommendationsSafetyLevelParam + } + case object TimelineScorer extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineScorerSafetyLevelParam + } + case object TopicsLandingPageTopicRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTopicsLandingPageTopicRecommendationsSafetyLevelParam + } + case object ExploreRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableExploreRecommendationsSafetyLevelParam + } + case object TimelineModeratedTweetsHydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + EnableTimelineModeratedTweetsHydrationSafetyLevelParam + } + case object TimelineInjection extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineInjectionSafetyLevelParam + } + case object 
TimelineMentions extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineMentionsSafetyLevelParam + } + case object TimelineHomeLatest extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineHomeLatestSafetyLevelParam + } + case object TimelineLikedBy extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineLikedBySafetyLevelParam + } + case object TimelineRetweetedBy extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineRetweetedBySafetyLevelParam + } + case object TimelineSuperLikedBy extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineLikedBySafetyLevelParam + } + case object TimelineBookmark extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineBookmarkSafetyLevelParam + } + case object TimelineContentControls extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineContentControlsSafetyLevelParam + } + case object TimelineMedia extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineMediaSafetyLevelParam + } + case object TimelineReactiveBlending extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineReactiveBlendingSafetyLevelParam + } + case object TimelineFavorites extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineFavoritesSafetyLevelParam + } + case object TimelineFavoritesSelfView extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineFavoritesSelfViewSafetyLevelParam + } + case object TimelineLists extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineListsSafetyLevelParam + } + case object TimelineProfile extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineProfileSafetyLevelParam + } + case object TimelineProfileAll extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineProfileAllSafetyLevelParam + } + + case object TimelineProfileSpaces extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineProfileSpacesSafetyLevelParam + } + + case object TimelineProfileSuperFollows extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineProfileSuperFollowsSafetyLevelParam + } + case object TimelineFocalTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTimelineFocalTweetSafetyLevelParam + } + case object Tombstoning extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTombstoningSafetyLevelParam + } + case object TopicRecommendations extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTopicRecommendationsSafetyLevelParam + } + case object TrendsRepresentativeTweet extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTrendsRepresentativeTweetSafetyLevelParam + } + case object TrustedFriendsUserList extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTrustedFriendsUserListSafetyLevelParam + } + case object TweetDetail extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetDetailSafetyLevelParam + } + case object TweetDetailNonToo extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetDetailNonTooSafetyLevelParam + } + case object TweetDetailWithInjectionsHydration extends SafetyLevel { + override val enabledParam: SafetyLevelParam = + 
EnableTweetDetailWithInjectionsHydrationSafetyLevelParam + } + case object TweetEngagers extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetEngagersSafetyLevelParam + } + case object TweetReplyNudge extends SafetyLevel { + override def enabledParam: SafetyLevelParam = EnableTweetReplyNudgeParam + } + case object TweetScopedTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetScopedTimelineSafetyLevelParam + } + case object TweetWritesApi extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetWritesApiSafetyLevelParam + } + case object TwitterArticleCompose extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTwitterArticleComposeSafetyLevelParam + } + case object TwitterArticleProfileTab extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTwitterArticleProfileTabSafetyLevelParam + } + case object TwitterArticleRead extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTwitterArticleReadSafetyLevelParam + } + case object UserProfileHeader extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserProfileHeaderSafetyLevelParam + } + case object UserMilestoneRecommendation extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserMilestoneRecommendationSafetyLevelParam + } + case object UserScopedTimeline extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserScopedTimelineSafetyLevelParam + } + case object UserSearchSrp extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserSearchSrpSafetyLevelParam + } + case object UserSearchTypeahead extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserSearchTypeaheadSafetyLevelParam + } + case object UserSelfViewOnly extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserSelfViewOnlySafetyLevelParam + } + case object UserSettings extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableUserSettingsSafetyLevelParam + } + case object VideoAds extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableVideoAdsSafetyLevelParam + } + case object ZipbirdConsumerArchives extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableZipbirdConsumerArchivesSafetyLevelParam + } + case object TweetAward extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableTweetAwardSafetyLevelParam + } + + case object DeprecatedSafetyLevel extends SafetyLevel { + override val enabledParam: SafetyLevelParam = EnableDeprecatedSafetyLevel + } + + + def fromThrift(safetyLevel: ThriftSafetyLevel): SafetyLevel = + thriftToModelMap.get(safetyLevel).getOrElse(DeprecatedSafetyLevel) + + def toThrift(safetyLevel: SafetyLevel): ThriftSafetyLevel = + modelToThriftMap + .get(safetyLevel).getOrElse(ThriftSafetyLevel.EnumUnknownSafetyLevel(DeprecatedEnumValue)) + + val List: Seq[SafetyLevel] = + ThriftSafetyLevel.list.map(fromThrift).filter(_ != DeprecatedSafetyLevel) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevelGroup.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevelGroup.scala new file mode 100644 index 000000000..e60daefd1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SafetyLevelGroup.scala @@ -0,0 +1,557 @@ +package com.twitter.visibility.models + +import 
com.twitter.visibility.models.SafetyLevel.AccessInternalPromotedContent +import com.twitter.visibility.models.SafetyLevel.AdsBusinessSettings +import com.twitter.visibility.models.SafetyLevel.AdsCampaign +import com.twitter.visibility.models.SafetyLevel.AdsManager +import com.twitter.visibility.models.SafetyLevel.AdsReportingDashboard +import com.twitter.visibility.models.SafetyLevel.AllSubscribedLists +import com.twitter.visibility.models.SafetyLevel.Appeals +import com.twitter.visibility.models.SafetyLevel.ArticleTweetTimeline +import com.twitter.visibility.models.SafetyLevel.BaseQig +import com.twitter.visibility.models.SafetyLevel.BirdwatchNeedsYourHelpNotifications +import com.twitter.visibility.models.SafetyLevel.BirdwatchNoteAuthor +import com.twitter.visibility.models.SafetyLevel.BirdwatchNoteTweetsTimeline +import com.twitter.visibility.models.SafetyLevel.BlockMuteUsersTimeline +import com.twitter.visibility.models.SafetyLevel.BrandSafety +import com.twitter.visibility.models.SafetyLevel.CardPollVoting +import com.twitter.visibility.models.SafetyLevel.CardsService +import com.twitter.visibility.models.SafetyLevel.ContentControlToolInstall +import com.twitter.visibility.models.SafetyLevel.ConversationFocalPrehydration +import com.twitter.visibility.models.SafetyLevel.ConversationFocalTweet +import com.twitter.visibility.models.SafetyLevel.ConversationInjectedTweet +import com.twitter.visibility.models.SafetyLevel.ConversationReply +import com.twitter.visibility.models.SafetyLevel.CuratedTrendsRepresentativeTweet +import com.twitter.visibility.models.SafetyLevel.CurationPolicyViolations +import com.twitter.visibility.models.SafetyLevel.DesFollowingAndFollowersUserList +import com.twitter.visibility.models.SafetyLevel.DesHomeTimeline +import com.twitter.visibility.models.SafetyLevel.DesQuoteTweetTimeline +import com.twitter.visibility.models.SafetyLevel.DesRealtime +import com.twitter.visibility.models.SafetyLevel.DesRealtimeSpamEnrichment +import com.twitter.visibility.models.SafetyLevel.DesRealtimeTweetFilter +import com.twitter.visibility.models.SafetyLevel.DesRetweetingUsers +import com.twitter.visibility.models.SafetyLevel.DesTweetDetail +import com.twitter.visibility.models.SafetyLevel.DesTweetLikingUsers +import com.twitter.visibility.models.SafetyLevel.DesUserBookmarks +import com.twitter.visibility.models.SafetyLevel.DesUserLikedTweets +import com.twitter.visibility.models.SafetyLevel.DesUserMentions +import com.twitter.visibility.models.SafetyLevel.DesUserTweets +import com.twitter.visibility.models.SafetyLevel.DevPlatformComplianceStream +import com.twitter.visibility.models.SafetyLevel.DevPlatformGetListTweets +import com.twitter.visibility.models.SafetyLevel.DirectMessages +import com.twitter.visibility.models.SafetyLevel.DirectMessagesConversationList +import com.twitter.visibility.models.SafetyLevel.DirectMessagesConversationTimeline +import com.twitter.visibility.models.SafetyLevel.DirectMessagesInbox +import com.twitter.visibility.models.SafetyLevel.DirectMessagesMutedUsers +import com.twitter.visibility.models.SafetyLevel.DirectMessagesPinned +import com.twitter.visibility.models.SafetyLevel.DirectMessagesSearch +import com.twitter.visibility.models.SafetyLevel.EditHistoryTimeline +import com.twitter.visibility.models.SafetyLevel.ElevatedQuoteTweetTimeline +import com.twitter.visibility.models.SafetyLevel.EmbedTweetMarkup +import com.twitter.visibility.models.SafetyLevel.EmbeddedTweet +import com.twitter.visibility.models.SafetyLevel.EmbedsPublicInterestNotice 
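+// The group objects defined below bundle related SafetyLevels into Sets (e.g. Dm collects the DirectMessages* levels); hence one import per referenced SafetyLevel member.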
+import com.twitter.visibility.models.SafetyLevel.ExploreRecommendations +import com.twitter.visibility.models.SafetyLevel.FilterAll +import com.twitter.visibility.models.SafetyLevel.FilterAllPlaceholder +import com.twitter.visibility.models.SafetyLevel.FilterDefault +import com.twitter.visibility.models.SafetyLevel.FilterNone +import com.twitter.visibility.models.SafetyLevel.FollowedTopicsTimeline +import com.twitter.visibility.models.SafetyLevel.FollowerConnections +import com.twitter.visibility.models.SafetyLevel.FollowingAndFollowersUserList +import com.twitter.visibility.models.SafetyLevel.ForDevelopmentOnly +import com.twitter.visibility.models.SafetyLevel.FriendsFollowingList +import com.twitter.visibility.models.SafetyLevel.GraphqlDefault +import com.twitter.visibility.models.SafetyLevel.GryphonDecksAndColumns +import com.twitter.visibility.models.SafetyLevel.HumanizationNudge +import com.twitter.visibility.models.SafetyLevel.KitchenSinkDevelopment +import com.twitter.visibility.models.SafetyLevel.ListHeader +import com.twitter.visibility.models.SafetyLevel.ListMemberships +import com.twitter.visibility.models.SafetyLevel.ListOwnerships +import com.twitter.visibility.models.SafetyLevel.ListRecommendations +import com.twitter.visibility.models.SafetyLevel.ListSearch +import com.twitter.visibility.models.SafetyLevel.ListSubscriptions +import com.twitter.visibility.models.SafetyLevel.LivePipelineEngagementCounts +import com.twitter.visibility.models.SafetyLevel.LiveVideoTimeline +import com.twitter.visibility.models.SafetyLevel.MagicRecs +import com.twitter.visibility.models.SafetyLevel.MagicRecsAggressive +import com.twitter.visibility.models.SafetyLevel.MagicRecsAggressiveV2 +import com.twitter.visibility.models.SafetyLevel.MagicRecsV2 +import com.twitter.visibility.models.SafetyLevel.Minimal +import com.twitter.visibility.models.SafetyLevel.ModeratedTweetsTimeline +import com.twitter.visibility.models.SafetyLevel.Moments +import com.twitter.visibility.models.SafetyLevel.NearbyTimeline +import com.twitter.visibility.models.SafetyLevel.NewUserExperience +import com.twitter.visibility.models.SafetyLevel.NotificationsIbis +import com.twitter.visibility.models.SafetyLevel.NotificationsPlatform +import com.twitter.visibility.models.SafetyLevel.NotificationsPlatformPush +import com.twitter.visibility.models.SafetyLevel.NotificationsQig +import com.twitter.visibility.models.SafetyLevel.NotificationsRead +import com.twitter.visibility.models.SafetyLevel.NotificationsTimelineDeviceFollow +import com.twitter.visibility.models.SafetyLevel.NotificationsWrite +import com.twitter.visibility.models.SafetyLevel.NotificationsWriterTweetHydrator +import com.twitter.visibility.models.SafetyLevel.NotificationsWriterV2 +import com.twitter.visibility.models.SafetyLevel.ProfileMixerFavorites +import com.twitter.visibility.models.SafetyLevel.ProfileMixerMedia +import com.twitter.visibility.models.SafetyLevel.QuickPromoteTweetEligibility +import com.twitter.visibility.models.SafetyLevel.QuoteTweetTimeline +import com.twitter.visibility.models.SafetyLevel.QuotedTweetRules +import com.twitter.visibility.models.SafetyLevel.RecosVideo +import com.twitter.visibility.models.SafetyLevel.RecosWritePath +import com.twitter.visibility.models.SafetyLevel.RepliesGrouping +import com.twitter.visibility.models.SafetyLevel.ReportCenter +import com.twitter.visibility.models.SafetyLevel.ReturningUserExperienceFocalTweet +import com.twitter.visibility.models.SafetyLevel.Revenue +import 
com.twitter.visibility.models.SafetyLevel.SafeSearchMinimal +import com.twitter.visibility.models.SafetyLevel.SafeSearchStrict +import com.twitter.visibility.models.SafetyLevel.SearchBlenderUserRules +import com.twitter.visibility.models.SafetyLevel.SearchHydration +import com.twitter.visibility.models.SafetyLevel.SearchLatest +import com.twitter.visibility.models.SafetyLevel.SearchLatestUserRules +import com.twitter.visibility.models.SafetyLevel.SearchMixerSrpMinimal +import com.twitter.visibility.models.SafetyLevel.SearchMixerSrpStrict +import com.twitter.visibility.models.SafetyLevel.SearchPeopleSrp +import com.twitter.visibility.models.SafetyLevel.SearchPeopleTypeahead +import com.twitter.visibility.models.SafetyLevel.SearchPhoto +import com.twitter.visibility.models.SafetyLevel.SearchTop +import com.twitter.visibility.models.SafetyLevel.SearchTopQig +import com.twitter.visibility.models.SafetyLevel.SearchTrendTakeoverPromotedTweet +import com.twitter.visibility.models.SafetyLevel.SearchVideo +import com.twitter.visibility.models.SafetyLevel.ShoppingManagerSpyMode +import com.twitter.visibility.models.SafetyLevel.SignalsReactions +import com.twitter.visibility.models.SafetyLevel.SignalsTweetReactingUsers +import com.twitter.visibility.models.SafetyLevel.SoftInterventionPivot +import com.twitter.visibility.models.SafetyLevel.SpaceFleetline +import com.twitter.visibility.models.SafetyLevel.SpaceHomeTimelineUpranking +import com.twitter.visibility.models.SafetyLevel.SpaceJoinScreen +import com.twitter.visibility.models.SafetyLevel.SpaceNotifications +import com.twitter.visibility.models.SafetyLevel.SpaceTweetAvatarHomeTimeline +import com.twitter.visibility.models.SafetyLevel.SpacesParticipants +import com.twitter.visibility.models.SafetyLevel.SpacesSellerApplicationStatus +import com.twitter.visibility.models.SafetyLevel.SpacesSharing +import com.twitter.visibility.models.SafetyLevel.StickersTimeline +import com.twitter.visibility.models.SafetyLevel.StratoExtLimitedEngagements +import com.twitter.visibility.models.SafetyLevel.StreamServices +import com.twitter.visibility.models.SafetyLevel.SuperFollowerConnections +import com.twitter.visibility.models.SafetyLevel.SuperLike +import com.twitter.visibility.models.SafetyLevel.Test +import com.twitter.visibility.models.SafetyLevel.TimelineBookmark +import com.twitter.visibility.models.SafetyLevel.TimelineContentControls +import com.twitter.visibility.models.SafetyLevel.TimelineConversations +import com.twitter.visibility.models.SafetyLevel.TimelineConversationsDownranking +import com.twitter.visibility.models.SafetyLevel.TimelineConversationsDownrankingMinimal +import com.twitter.visibility.models.SafetyLevel.TimelineFavorites +import com.twitter.visibility.models.SafetyLevel.TimelineFavoritesSelfView +import com.twitter.visibility.models.SafetyLevel.TimelineFocalTweet +import com.twitter.visibility.models.SafetyLevel.TimelineFollowingActivity +import com.twitter.visibility.models.SafetyLevel.TimelineHomeCommunities +import com.twitter.visibility.models.SafetyLevel.TimelineHomeHydration +import com.twitter.visibility.models.SafetyLevel.TimelineHomeLatest +import com.twitter.visibility.models.SafetyLevel.TimelineHomePromotedHydration +import com.twitter.visibility.models.SafetyLevel.TimelineHomeRecommendations +import com.twitter.visibility.models.SafetyLevel.TimelineHomeTopicFollowRecommendations +import com.twitter.visibility.models.SafetyLevel.TimelineInjection +import com.twitter.visibility.models.SafetyLevel.TimelineLikedBy +import 
com.twitter.visibility.models.SafetyLevel.TimelineLists +import com.twitter.visibility.models.SafetyLevel.TimelineMedia +import com.twitter.visibility.models.SafetyLevel.TimelineMentions +import com.twitter.visibility.models.SafetyLevel.TimelineModeratedTweetsHydration +import com.twitter.visibility.models.SafetyLevel.TimelineProfileAll +import com.twitter.visibility.models.SafetyLevel.TimelineProfileSpaces +import com.twitter.visibility.models.SafetyLevel.TimelineProfileSuperFollows +import com.twitter.visibility.models.SafetyLevel.TimelineScorer +import com.twitter.visibility.models.SafetyLevel.Tombstoning +import com.twitter.visibility.models.SafetyLevel.TopicsLandingPageTopicRecommendations +import com.twitter.visibility.models.SafetyLevel.TrendsRepresentativeTweet +import com.twitter.visibility.models.SafetyLevel.TrustedFriendsUserList +import com.twitter.visibility.models.SafetyLevel.TweetDetail +import com.twitter.visibility.models.SafetyLevel.TweetDetailNonToo +import com.twitter.visibility.models.SafetyLevel.TweetDetailWithInjectionsHydration +import com.twitter.visibility.models.SafetyLevel.TweetEngagers +import com.twitter.visibility.models.SafetyLevel.TweetReplyNudge +import com.twitter.visibility.models.SafetyLevel.TweetWritesApi +import com.twitter.visibility.models.SafetyLevel.TwitterArticleCompose +import com.twitter.visibility.models.SafetyLevel.TwitterArticleProfileTab +import com.twitter.visibility.models.SafetyLevel.TwitterArticleRead +import com.twitter.visibility.models.SafetyLevel.UserMilestoneRecommendation +import com.twitter.visibility.models.SafetyLevel.UserProfileHeader +import com.twitter.visibility.models.SafetyLevel.UserSelfViewOnly +import com.twitter.visibility.models.SafetyLevel.UserSettings +import com.twitter.visibility.models.SafetyLevel.VideoAds +import com.twitter.visibility.models.SafetyLevel.WritePathLimitedActionsEnforcement +import com.twitter.visibility.models.SafetyLevel.ZipbirdConsumerArchives + +sealed trait SafetyLevelGroup { val levels: Set[SafetyLevel] } + +object SafetyLevelGroup { + case object Ads extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + AdsBusinessSettings, + AdsCampaign, + AdsManager, + AdsReportingDashboard, + BrandSafety, + VideoAds, + QuickPromoteTweetEligibility + ) + } + + case object ArticleTimeline extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + ArticleTweetTimeline, + ) + } + + case object ArticleTweets extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + TwitterArticleCompose, + TwitterArticleProfileTab, + TwitterArticleRead, + ) + } + + case object Birdwatch extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + BirdwatchNoteAuthor, + BirdwatchNoteTweetsTimeline, + BirdwatchNeedsYourHelpNotifications, + ) + } + + case object Cards extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + CardPollVoting, + CardsService, + ) + } + + case object Communities extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafetyLevel.Communities + ) + } + + case object Conversation extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + ConversationFocalPrehydration, + ConversationFocalTweet, + ConversationInjectedTweet, + ConversationReply, + Tombstoning, + ) + } + + case object CreativeContainerService extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + AccessInternalPromotedContent + ) + } + + case object Des extends SafetyLevelGroup { + override val 
levels: Set[SafetyLevel] = Set( + DevPlatformGetListTweets, + DesFollowingAndFollowersUserList, + DesHomeTimeline, + DesQuoteTweetTimeline, + DesRetweetingUsers, + DesTweetDetail, + DesTweetLikingUsers, + DesUserBookmarks, + DesUserLikedTweets, + DesUserMentions, + DesUserTweets, + DevPlatformComplianceStream, + ) + } + + case object DesStream extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + DesRealtime, + DesRealtimeSpamEnrichment, + DesRealtimeTweetFilter, + ) + } + + case object Dm extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + DirectMessages, + DirectMessagesConversationList, + DirectMessagesConversationTimeline, + DirectMessagesInbox, + DirectMessagesMutedUsers, + DirectMessagesPinned, + DirectMessagesSearch, + ) + } + + case object Followers extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + FollowedTopicsTimeline, + FollowerConnections, + FollowingAndFollowersUserList, + FriendsFollowingList, + ) + } + + case object Graphql extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + FilterDefault, + GraphqlDefault, + SoftInterventionPivot, + ) + } + + case object Jiminy extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + HumanizationNudge, + TweetReplyNudge, + ) + } + + case object Lists extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + AllSubscribedLists, + ListHeader, + ListMemberships, + ListOwnerships, + ListRecommendations, + ListSearch, + ListSubscriptions, + ) + } + + case object Notifications extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + NotificationsIbis, + NotificationsPlatform, + NotificationsPlatformPush, + NotificationsQig, + NotificationsRead, + NotificationsTimelineDeviceFollow, + NotificationsWrite, + NotificationsWriterTweetHydrator, + NotificationsWriterV2, + ) + } + + case object Other extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + CuratedTrendsRepresentativeTweet, + CurationPolicyViolations, + BaseQig, + Appeals, + ContentControlToolInstall, + EditHistoryTimeline, + ElevatedQuoteTweetTimeline, + EmbeddedTweet, + EmbedsPublicInterestNotice, + EmbedTweetMarkup, + ExploreRecommendations, + WritePathLimitedActionsEnforcement, + LiveVideoTimeline, + LivePipelineEngagementCounts, + Minimal, + Moments, + NearbyTimeline, + NewUserExperience, + QuoteTweetTimeline, + QuotedTweetRules, + ReportCenter, + Revenue, + ShoppingManagerSpyMode, + StickersTimeline, + SuperLike, + TrendsRepresentativeTweet, + TrustedFriendsUserList, + GryphonDecksAndColumns, + TweetEngagers, + TweetWritesApi, + UserMilestoneRecommendation, + StreamServices, + ZipbirdConsumerArchives + ) + } + + case object Profile extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + UserProfileHeader, + UserSelfViewOnly, + UserSettings, + ) + } + + case object ProfileMixer extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + ProfileMixerMedia, + ProfileMixerFavorites, + ) + } + + case object Reactions extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SignalsReactions, + SignalsTweetReactingUsers, + ) + } + + case object Recommendations extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + MagicRecs, + MagicRecsV2, + MagicRecsAggressive, + MagicRecsAggressiveV2, + SafetyLevel.Recommendations, + RecosVideo, + RecosWritePath, + ) + } + + case object ReturningUserExperience extends SafetyLevelGroup { + override val levels: 
Set[SafetyLevel] = Set( + SafetyLevel.ReturningUserExperience, + ReturningUserExperienceFocalTweet, + ) + } + + case object SafeSearch extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafeSearchMinimal, + SafeSearchStrict, + ) + } + + case object Search extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SearchHydration, + SearchLatest, + SearchTop, + SearchTopQig, + SearchPeopleSrp, + SearchPeopleTypeahead, + SearchPhoto, + SearchTrendTakeoverPromotedTweet, + SearchVideo, + SearchBlenderUserRules, + SearchLatestUserRules, + ) + } + + case object SearchMixer extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SearchMixerSrpMinimal, + SearchMixerSrpStrict, + ) + } + + case object Socialproof extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafetyLevel.SocialProof + ) + } + + case object Spaces extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SpaceFleetline, + SpaceHomeTimelineUpranking, + SpaceJoinScreen, + SpaceNotifications, + SafetyLevel.Spaces, + SpacesParticipants, + SpacesSellerApplicationStatus, + SpacesSharing, + SpaceTweetAvatarHomeTimeline, + ) + } + + case object Strato extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + StratoExtLimitedEngagements + ) + } + + case object Superfollows extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SuperFollowerConnections, + TimelineProfileSuperFollows, + ) + } + + case object Testing extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + ForDevelopmentOnly, + KitchenSinkDevelopment, + Test, + ) + } + + case object Timeline extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + BlockMuteUsersTimeline, + TimelineBookmark, + TimelineContentControls, + TimelineConversationsDownranking, + TimelineConversationsDownrankingMinimal, + TimelineFavorites, + TimelineFavoritesSelfView, + TimelineFollowingActivity, + TimelineScorer, + TimelineInjection, + TimelineLikedBy, + TimelineLists, + TimelineMedia, + TimelineMentions, + ModeratedTweetsTimeline, + TimelineModeratedTweetsHydration, + ) + } + + case object TopicRecommendations extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafetyLevel.TopicRecommendations, + TopicsLandingPageTopicRecommendations, + ) + } + + case object TimelineProfile extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafetyLevel.TimelineProfile, + TimelineProfileAll, + TimelineProfileSpaces, + ) + } + + case object TimelineHome extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + SafetyLevel.TimelineHome, + TimelineHomeCommunities, + TimelineHomeHydration, + TimelineHomeLatest, + TimelineHomePromotedHydration, + TimelineHomeRecommendations, + TimelineHomeTopicFollowRecommendations, + ) + } + + case object TlsApi extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + TimelineConversations, + TimelineFocalTweet, + ) + } + + case object TweetDetails extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + TweetDetail, + TweetDetailNonToo, + TweetDetailWithInjectionsHydration, + RepliesGrouping, + ) + } + + case object Special extends SafetyLevelGroup { + override val levels: Set[SafetyLevel] = Set( + FilterAll, + FilterAllPlaceholder, + FilterNone, + ) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SemanticCoreAnnotation.scala 
b/visibilitylib/src/main/scala/com/twitter/visibility/models/SemanticCoreAnnotation.scala new file mode 100644 index 000000000..c3d651422 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SemanticCoreAnnotation.scala @@ -0,0 +1,3 @@ +package com.twitter.visibility.models + +case class SemanticCoreAnnotation(groupId: Long, domainId: Long, entityId: Long) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/SpaceSafetyLabelType.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/SpaceSafetyLabelType.scala new file mode 100644 index 000000000..432650dfd --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/SpaceSafetyLabelType.scala @@ -0,0 +1,95 @@ +package com.twitter.visibility.models + +import com.twitter.visibility.safety_label_store.{thriftscala => s} +import com.twitter.visibility.util.NamingUtils + +sealed trait SpaceSafetyLabelType extends SafetyLabelType { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +object SpaceSafetyLabelType extends SafetyLabelType { + + val List: List[SpaceSafetyLabelType] = s.SpaceSafetyLabelType.list.map(fromThrift) + + val ActiveLabels: List[SpaceSafetyLabelType] = List.filter { labelType => + labelType != Unknown && labelType != Deprecated + } + + private lazy val nameToValueMap: Map[String, SpaceSafetyLabelType] = + List.map(l => l.name.toLowerCase -> l).toMap + def fromName(name: String): Option[SpaceSafetyLabelType] = nameToValueMap.get(name.toLowerCase) + + private val UnknownThriftSafetyLabelType = + s.SpaceSafetyLabelType.EnumUnknownSpaceSafetyLabelType(UnknownEnumValue) + + private lazy val thriftToModelMap: Map[s.SpaceSafetyLabelType, SpaceSafetyLabelType] = Map( + s.SpaceSafetyLabelType.DoNotAmplify -> DoNotAmplify, + s.SpaceSafetyLabelType.CoordinatedHarmfulActivityHighRecall -> CoordinatedHarmfulActivityHighRecall, + s.SpaceSafetyLabelType.UntrustedUrl -> UntrustedUrl, + s.SpaceSafetyLabelType.MisleadingHighRecall -> MisleadingHighRecall, + s.SpaceSafetyLabelType.NsfwHighPrecision -> NsfwHighPrecision, + s.SpaceSafetyLabelType.NsfwHighRecall -> NsfwHighRecall, + s.SpaceSafetyLabelType.CivicIntegrityMisinfo -> CivicIntegrityMisinfo, + s.SpaceSafetyLabelType.MedicalMisinfo -> MedicalMisinfo, + s.SpaceSafetyLabelType.GenericMisinfo -> GenericMisinfo, + s.SpaceSafetyLabelType.DmcaWithheld -> DmcaWithheld, + s.SpaceSafetyLabelType.HatefulHighRecall -> HatefulHighRecall, + s.SpaceSafetyLabelType.ViolenceHighRecall -> ViolenceHighRecall, + s.SpaceSafetyLabelType.HighToxicityModelScore -> HighToxicityModelScore, + s.SpaceSafetyLabelType.UkraineCrisisTopic -> UkraineCrisisTopic, + s.SpaceSafetyLabelType.DoNotPublicPublish -> DoNotPublicPublish, + s.SpaceSafetyLabelType.Reserved16 -> Deprecated, + s.SpaceSafetyLabelType.Reserved17 -> Deprecated, + s.SpaceSafetyLabelType.Reserved18 -> Deprecated, + s.SpaceSafetyLabelType.Reserved19 -> Deprecated, + s.SpaceSafetyLabelType.Reserved20 -> Deprecated, + s.SpaceSafetyLabelType.Reserved21 -> Deprecated, + s.SpaceSafetyLabelType.Reserved22 -> Deprecated, + s.SpaceSafetyLabelType.Reserved23 -> Deprecated, + s.SpaceSafetyLabelType.Reserved24 -> Deprecated, + s.SpaceSafetyLabelType.Reserved25 -> Deprecated, + ) + + private lazy val modelToThriftMap: Map[SpaceSafetyLabelType, s.SpaceSafetyLabelType] = + (for ((k, v) <- thriftToModelMap) yield (v, k)) ++ Map( + Deprecated -> s.SpaceSafetyLabelType.EnumUnknownSpaceSafetyLabelType(DeprecatedEnumValue), + ) + + case object DoNotAmplify extends SpaceSafetyLabelType + 
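// Note: the Reserved16-25 thrift slots get no dedicated case objects; thriftToModelMap above folds each of them into the shared Deprecated object. + 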
case object CoordinatedHarmfulActivityHighRecall extends SpaceSafetyLabelType + case object UntrustedUrl extends SpaceSafetyLabelType + case object MisleadingHighRecall extends SpaceSafetyLabelType + case object NsfwHighPrecision extends SpaceSafetyLabelType + case object NsfwHighRecall extends SpaceSafetyLabelType + case object CivicIntegrityMisinfo extends SpaceSafetyLabelType + case object MedicalMisinfo extends SpaceSafetyLabelType + case object GenericMisinfo extends SpaceSafetyLabelType + case object DmcaWithheld extends SpaceSafetyLabelType + case object HatefulHighRecall extends SpaceSafetyLabelType + case object ViolenceHighRecall extends SpaceSafetyLabelType + case object HighToxicityModelScore extends SpaceSafetyLabelType + + case object UkraineCrisisTopic extends SpaceSafetyLabelType + + case object DoNotPublicPublish extends SpaceSafetyLabelType + + case object Deprecated extends SpaceSafetyLabelType + case object Unknown extends SpaceSafetyLabelType + + def fromThrift(safetyLabelType: s.SpaceSafetyLabelType): SpaceSafetyLabelType = + thriftToModelMap.get(safetyLabelType) match { + case Some(spaceSafetyLabelType) => spaceSafetyLabelType + case _ => + safetyLabelType match { + case s.SpaceSafetyLabelType.EnumUnknownSpaceSafetyLabelType(DeprecatedEnumValue) => + Deprecated + case _ => + Unknown + } + } + + def toThrift(safetyLabelType: SpaceSafetyLabelType): s.SpaceSafetyLabelType = { + modelToThriftMap + .get(safetyLabelType).getOrElse(UnknownThriftSafetyLabelType) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetDeleteReason.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetDeleteReason.scala new file mode 100644 index 000000000..26cfbe709 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetDeleteReason.scala @@ -0,0 +1,6 @@ +package com.twitter.visibility.models + +object TweetDeleteReason extends Enumeration { + type TweetDeleteReason = Value + val Deleted, BounceDeleted = Value +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetModelMetadata.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetModelMetadata.scala new file mode 100644 index 000000000..e6303a03d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetModelMetadata.scala @@ -0,0 +1,23 @@ +package com.twitter.visibility.models + +import com.twitter.spam.rtf.{thriftscala => s} + +case class TweetModelMetadata( + version: Option[Int] = None, + calibratedLanguage: Option[String] = None) + +object TweetModelMetadata { + + def fromThrift(metadata: s.ModelMetadata): Option[TweetModelMetadata] = { + metadata match { + case s.ModelMetadata.ModelMetadataV1(s.ModelMetadataV1(version, calibratedLanguage)) => + Some(TweetModelMetadata(version, calibratedLanguage)) + case _ => None + } + } + + def toThrift(metadata: TweetModelMetadata): s.ModelMetadata = { + s.ModelMetadata.ModelMetadataV1( + s.ModelMetadataV1(metadata.version, metadata.calibratedLanguage)) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetSafetyLabel.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetSafetyLabel.scala new file mode 100644 index 000000000..b830eb5e2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/TweetSafetyLabel.scala @@ -0,0 +1,360 @@ +package com.twitter.visibility.models + +import com.twitter.spam.rtf.thriftscala.SafetyLabelSource +import com.twitter.spam.rtf.{thriftscala => s} +import 
com.twitter.util.Time +import com.twitter.visibility.util.NamingUtils + +sealed trait TweetSafetyLabelType extends SafetyLabelType with Product with Serializable { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +case class TweetSafetyLabel( + labelType: TweetSafetyLabelType, + source: Option[LabelSource] = None, + applicableUsers: Set[Long] = Set.empty, + modelMetadata: Option[TweetModelMetadata] = None, + score: Option[Double] = None, + safetyLabelSource: Option[SafetyLabelSource] = None) + +object TweetSafetyLabelType extends SafetyLabelType { + + val List: List[TweetSafetyLabelType] = s.SafetyLabelType.list.map(fromThrift) + + val ActiveLabels: List[TweetSafetyLabelType] = List.filter { labelType => + labelType != Unknown && labelType != Deprecated + } + + private lazy val nameToValueMap: Map[String, TweetSafetyLabelType] = + List.map(l => l.name.toLowerCase -> l).toMap + def fromName(name: String): Option[TweetSafetyLabelType] = nameToValueMap.get(name.toLowerCase) + + private val UnknownThriftSafetyLabelType = + s.SafetyLabelType.EnumUnknownSafetyLabelType(UnknownEnumValue) + + private lazy val thriftToModelMap: Map[s.SafetyLabelType, TweetSafetyLabelType] = Map( + s.SafetyLabelType.Abusive -> Abusive, + s.SafetyLabelType.AbusiveBehavior -> AbusiveBehavior, + s.SafetyLabelType.AbusiveBehaviorInsults -> AbusiveBehaviorInsults, + s.SafetyLabelType.AbusiveBehaviorViolentThreat -> AbusiveBehaviorViolentThreat, + s.SafetyLabelType.AbusiveBehaviorMajorAbuse -> AbusiveBehaviorMajorAbuse, + s.SafetyLabelType.AbusiveHighRecall -> AbusiveHighRecall, + s.SafetyLabelType.AdsManagerDenyList -> AdsManagerDenyList, + s.SafetyLabelType.AgathaSpam -> AgathaSpam, + s.SafetyLabelType.Automation -> Automation, + s.SafetyLabelType.AutomationHighRecall -> AutomationHighRecall, + s.SafetyLabelType.Bounce -> Bounce, + s.SafetyLabelType.BounceEdits -> BounceEdits, + s.SafetyLabelType.BrandSafetyNsfaAggregate -> BrandSafetyNsfaAggregate, + s.SafetyLabelType.BrandSafetyExperimental1 -> BrandSafetyExperimental1, + s.SafetyLabelType.BrandSafetyExperimental2 -> BrandSafetyExperimental2, + s.SafetyLabelType.BrandSafetyExperimental3 -> BrandSafetyExperimental3, + s.SafetyLabelType.BrandSafetyExperimental4 -> BrandSafetyExperimental4, + s.SafetyLabelType.BystanderAbusive -> BystanderAbusive, + s.SafetyLabelType.CopypastaSpam -> CopypastaSpam, + s.SafetyLabelType.DoNotAmplify -> DoNotAmplify, + s.SafetyLabelType.DownrankSpamReply -> DownrankSpamReply, + s.SafetyLabelType.DuplicateContent -> DuplicateContent, + s.SafetyLabelType.DuplicateMention -> DuplicateMention, + s.SafetyLabelType.DynamicProductAd -> DynamicProductAd, + s.SafetyLabelType.EdiDevelopmentOnly -> EdiDevelopmentOnly, + s.SafetyLabelType.ExperimentalNudge -> ExperimentalNudge, + s.SafetyLabelType.ExperimentalSensitiveIllegal2 -> ExperimentalSensitiveIllegal2, + s.SafetyLabelType.ForEmergencyUseOnly -> ForEmergencyUseOnly, + s.SafetyLabelType.GoreAndViolence -> GoreAndViolence, + s.SafetyLabelType.GoreAndViolenceHighPrecision -> GoreAndViolenceHighPrecision, + s.SafetyLabelType.GoreAndViolenceHighRecall -> GoreAndViolenceHighRecall, + s.SafetyLabelType.GoreAndViolenceReportedHeuristics -> GoreAndViolenceReportedHeuristics, + s.SafetyLabelType.GoreAndViolenceTopicHighRecall -> GoreAndViolenceTopicHighRecall, + s.SafetyLabelType.HatefulConduct -> HatefulConduct, + s.SafetyLabelType.HatefulConductViolentThreat -> HatefulConductViolentThreat, + s.SafetyLabelType.HighCryptospamScore -> HighCryptospamScore, + 
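// Score-style labels (HighP*/High*Score) continue below and still map 1:1 to their thrift values; any numeric model output is carried separately (presumably via TweetSafetyLabel.score above). + 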
s.SafetyLabelType.HighPReportedTweetScore -> HighPReportedTweetScore, + s.SafetyLabelType.HighPSpammyTweetScore -> HighPSpammyTweetScore, + s.SafetyLabelType.HighPblockScore -> HighPblockScore, + s.SafetyLabelType.HighProactiveTosScore -> HighProactiveTosScore, + s.SafetyLabelType.HighSpammyTweetContentScore -> HighSpammyTweetContentScore, + s.SafetyLabelType.HighToxicityScore -> HighToxicityScore, + s.SafetyLabelType.HighlyReportedAndMidhighToxicityScore -> HighlyReportedAndMidhighToxicityScore, + s.SafetyLabelType.HighlyReportedTweet -> HighlyReportedTweet, + s.SafetyLabelType.InterstitialDevelopmentOnly -> InterstitialDevelopmentOnly, + s.SafetyLabelType.IpiDevelopmentOnly -> IpiDevelopmentOnly, + s.SafetyLabelType.LiveLowQuality -> LiveLowQuality, + s.SafetyLabelType.LowQuality -> LowQuality, + s.SafetyLabelType.LowQualityMention -> LowQualityMention, + s.SafetyLabelType.MisinfoCivic -> MisinfoCivic, + s.SafetyLabelType.MisinfoCrisis -> MisinfoCrisis, + s.SafetyLabelType.MisinfoGeneric -> MisinfoGeneric, + s.SafetyLabelType.MisinfoMedical -> MisinfoMedical, + s.SafetyLabelType.NsfaHighPrecision -> NsfaHighPrecision, + s.SafetyLabelType.NsfaHighRecall -> NsfaHighRecall, + s.SafetyLabelType.NsfwCardImage -> NsfwCardImage, + s.SafetyLabelType.NsfwHighPrecision -> NsfwHighPrecision, + s.SafetyLabelType.NsfwHighRecall -> NsfwHighRecall, + s.SafetyLabelType.NsfwReportedHeuristics -> NsfwReportedHeuristics, + s.SafetyLabelType.NsfwText -> NsfwText, + s.SafetyLabelType.NsfwTextHighPrecision -> NsfwTextHighPrecision, + s.SafetyLabelType.NsfwVideo -> NsfwVideo, + s.SafetyLabelType.PNegMultimodalHighPrecision -> PNegMultimodalHighPrecision, + s.SafetyLabelType.PNegMultimodalHighRecall -> PNegMultimodalHighRecall, + s.SafetyLabelType.Pdna -> Pdna, + s.SafetyLabelType.RecommendationsLowQuality -> RecommendationsLowQuality, + s.SafetyLabelType.RitoActionedTweet -> RitoActionedTweet, + s.SafetyLabelType.SafetyCrisis -> SafetyCrisis, + s.SafetyLabelType.SearchBlacklist -> SearchBlacklist, + s.SafetyLabelType.SearchBlacklistHighRecall -> SearchBlacklistHighRecall, + s.SafetyLabelType.SemanticCoreMisinformation -> SemanticCoreMisinformation, + s.SafetyLabelType.SmyteSpamTweet -> SmyteSpamTweet, + s.SafetyLabelType.Spam -> Spam, + s.SafetyLabelType.SpamHighRecall -> SpamHighRecall, + s.SafetyLabelType.TombstoneDevelopmentOnly -> TombstoneDevelopmentOnly, + s.SafetyLabelType.TweetContainsHatefulConductSlurHighSeverity -> TweetContainsHatefulConductSlurHighSeverity, + s.SafetyLabelType.TweetContainsHatefulConductSlurMediumSeverity -> TweetContainsHatefulConductSlurMediumSeverity, + s.SafetyLabelType.TweetContainsHatefulConductSlurLowSeverity -> TweetContainsHatefulConductSlurLowSeverity, + s.SafetyLabelType.UnsafeUrl -> UnsafeUrl, + s.SafetyLabelType.UntrustedUrl -> UntrustedUrl, + s.SafetyLabelType.FosnrHatefulConduct -> FosnrHatefulConduct, + s.SafetyLabelType.FosnrHatefulConductLowSeveritySlur -> FosnrHatefulConductLowSeveritySlur, + s.SafetyLabelType.AbusiveHighRecall2 -> Deprecated, + s.SafetyLabelType.AbusiveHighRecall3 -> Deprecated, + s.SafetyLabelType.BrazilianPoliticalTweet -> Deprecated, + s.SafetyLabelType.BystanderAbusive2 -> Deprecated, + s.SafetyLabelType.BystanderAbusive3 -> Deprecated, + s.SafetyLabelType.DeprecatedLabel144 -> Deprecated, + s.SafetyLabelType.Experimental10Seh -> Deprecated, + s.SafetyLabelType.Experimental11Seh -> Deprecated, + s.SafetyLabelType.Experimental12Seh -> Deprecated, + s.SafetyLabelType.Experimental13Seh -> Deprecated, + s.SafetyLabelType.Experimental14Seh -> 
Deprecated, + s.SafetyLabelType.Experimental15Seh -> Deprecated, + s.SafetyLabelType.Experimental16Seh -> Deprecated, + s.SafetyLabelType.Experimental17Seh -> Deprecated, + s.SafetyLabelType.Experimental18Seh -> Deprecated, + s.SafetyLabelType.Experimental19Seh -> Deprecated, + s.SafetyLabelType.Experimental1Seh -> Deprecated, + s.SafetyLabelType.Experimental20Seh -> Deprecated, + s.SafetyLabelType.Experimental21Seh -> Deprecated, + s.SafetyLabelType.Experimental22Seh -> Deprecated, + s.SafetyLabelType.Experimental23Seh -> Deprecated, + s.SafetyLabelType.Experimental24Seh -> Deprecated, + s.SafetyLabelType.Experimental25Seh -> Deprecated, + s.SafetyLabelType.Experimental2Seh -> Deprecated, + s.SafetyLabelType.Experimental3Seh -> Deprecated, + s.SafetyLabelType.Experimental4Seh -> Deprecated, + s.SafetyLabelType.Experimental5Seh -> Deprecated, + s.SafetyLabelType.Experimental6Seh -> Deprecated, + s.SafetyLabelType.Experimental7Seh -> Deprecated, + s.SafetyLabelType.Experimental8Seh -> Deprecated, + s.SafetyLabelType.Experimental9Seh -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore1 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore10 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore2 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore3 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore4 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore5 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore6 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore7 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore8 -> Deprecated, + s.SafetyLabelType.ExperimentalHighHealthModelScore9 -> Deprecated, + s.SafetyLabelType.ExperimentalSensitiveIllegal1 -> Deprecated, + s.SafetyLabelType.ExperimentalSensitiveIllegal3 -> Deprecated, + s.SafetyLabelType.ExperimentalSensitiveIllegal4 -> Deprecated, + s.SafetyLabelType.ExperimentalSensitiveIllegal5 -> Deprecated, + s.SafetyLabelType.ExperimentalSensitiveIllegal6 -> Deprecated, + s.SafetyLabelType.ExperimentalSpam1 -> Deprecated, + s.SafetyLabelType.ExperimentalSpam2 -> Deprecated, + s.SafetyLabelType.ExperimentalSpam3 -> Deprecated, + s.SafetyLabelType.Experimentation -> Deprecated, + s.SafetyLabelType.Experimentation2 -> Deprecated, + s.SafetyLabelType.Experimentation3 -> Deprecated, + s.SafetyLabelType.HighlyReportedImage -> Deprecated, + s.SafetyLabelType.HighToxicityHoldbackModelScore -> Deprecated, + s.SafetyLabelType.LowQualityHighRecall -> Deprecated, + s.SafetyLabelType.MagicRecsDenylist -> Deprecated, + s.SafetyLabelType.MisinfoCovid19 -> Deprecated, + s.SafetyLabelType.MsnfoBrazilianElection -> Deprecated, + s.SafetyLabelType.MsnfoCovid19Vaccine -> Deprecated, + s.SafetyLabelType.MsnfoFrenchElection -> Deprecated, + s.SafetyLabelType.MsnfoPhilippineElection -> Deprecated, + s.SafetyLabelType.MsnfoUsElection -> Deprecated, + s.SafetyLabelType.NsfwNearPerfect -> Deprecated, + s.SafetyLabelType.PersonaNonGrata -> Deprecated, + s.SafetyLabelType.PMisinfoCombined15 -> Deprecated, + s.SafetyLabelType.PMisinfoCombined30 -> Deprecated, + s.SafetyLabelType.PMisinfoCombined50 -> Deprecated, + s.SafetyLabelType.PMisinfoDenylist -> Deprecated, + s.SafetyLabelType.PMisinfoPVeracityNudge -> Deprecated, + s.SafetyLabelType.PoliticalTweetExperimental1 -> Deprecated, + s.SafetyLabelType.ProactiveTosHighRecall -> Deprecated, + s.SafetyLabelType.ProactiveTosHighRecallContainsSelfHarm -> Deprecated, + 
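// (retired labels continue: every remaining entry through Reserved170 decodes to the shared Deprecated object) + 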
s.SafetyLabelType.ProactiveTosHighRecallEncourageSelfHarm -> Deprecated, + s.SafetyLabelType.ProactiveTosHighRecallEpisodic -> Deprecated, + s.SafetyLabelType.ProactiveTosHighRecallEpisodicHatefulConduct -> Deprecated, + s.SafetyLabelType.ProactiveTosHighRecallOtherAbusePolicy -> Deprecated, + s.SafetyLabelType.ProjectLibra -> Deprecated, + s.SafetyLabelType.SearchHighVisibilityDenylist -> Deprecated, + s.SafetyLabelType.SearchHighVisibilityHighRecallDenylist -> Deprecated, + s.SafetyLabelType.Reserved162 -> Deprecated, + s.SafetyLabelType.Reserved163 -> Deprecated, + s.SafetyLabelType.Reserved164 -> Deprecated, + s.SafetyLabelType.Reserved165 -> Deprecated, + s.SafetyLabelType.Reserved166 -> Deprecated, + s.SafetyLabelType.Reserved167 -> Deprecated, + s.SafetyLabelType.Reserved168 -> Deprecated, + s.SafetyLabelType.Reserved169 -> Deprecated, + s.SafetyLabelType.Reserved170 -> Deprecated, + ) + + private lazy val modelToThriftMap: Map[TweetSafetyLabelType, s.SafetyLabelType] = + (for ((k, v) <- thriftToModelMap) yield (v, k)) ++ Map( + Deprecated -> s.SafetyLabelType.EnumUnknownSafetyLabelType(DeprecatedEnumValue), + ) + + case object Abusive extends TweetSafetyLabelType + case object AbusiveBehavior extends TweetSafetyLabelType + case object AbusiveBehaviorInsults extends TweetSafetyLabelType + case object AbusiveBehaviorViolentThreat extends TweetSafetyLabelType + case object AbusiveBehaviorMajorAbuse extends TweetSafetyLabelType + case object AbusiveHighRecall extends TweetSafetyLabelType + case object Automation extends TweetSafetyLabelType + case object AutomationHighRecall extends TweetSafetyLabelType + case object Bounce extends TweetSafetyLabelType + case object BystanderAbusive extends TweetSafetyLabelType + case object NsfaHighRecall extends TweetSafetyLabelType + case object DuplicateContent extends TweetSafetyLabelType + case object DuplicateMention extends TweetSafetyLabelType + case object GoreAndViolence extends TweetSafetyLabelType { + + val DeprecatedAt: Time = Time.at("2019-09-12 00:00:00 UTC") + } + case object GoreAndViolenceHighRecall extends TweetSafetyLabelType + case object LiveLowQuality extends TweetSafetyLabelType + case object LowQuality extends TweetSafetyLabelType + case object LowQualityMention extends TweetSafetyLabelType + case object NsfwCardImage extends TweetSafetyLabelType + case object NsfwHighRecall extends TweetSafetyLabelType + case object NsfwHighPrecision extends TweetSafetyLabelType + case object NsfwVideo extends TweetSafetyLabelType + case object Pdna extends TweetSafetyLabelType + + case object RecommendationsLowQuality extends TweetSafetyLabelType + case object SearchBlacklist extends TweetSafetyLabelType + case object Spam extends TweetSafetyLabelType + case object SpamHighRecall extends TweetSafetyLabelType + case object UntrustedUrl extends TweetSafetyLabelType + case object HighToxicityScore extends TweetSafetyLabelType + case object HighPblockScore extends TweetSafetyLabelType + case object SearchBlacklistHighRecall extends TweetSafetyLabelType + case object ForEmergencyUseOnly extends TweetSafetyLabelType + case object HighProactiveTosScore extends TweetSafetyLabelType + case object SafetyCrisis extends TweetSafetyLabelType + case object MisinfoCivic extends TweetSafetyLabelType + case object MisinfoCrisis extends TweetSafetyLabelType + case object MisinfoGeneric extends TweetSafetyLabelType + case object MisinfoMedical extends TweetSafetyLabelType + case object AdsManagerDenyList extends TweetSafetyLabelType + case object 
GoreAndViolenceHighPrecision extends TweetSafetyLabelType + case object NsfwReportedHeuristics extends TweetSafetyLabelType + case object GoreAndViolenceReportedHeuristics extends TweetSafetyLabelType + case object HighPSpammyTweetScore extends TweetSafetyLabelType + case object DoNotAmplify extends TweetSafetyLabelType + case object HighlyReportedTweet extends TweetSafetyLabelType + case object AgathaSpam extends TweetSafetyLabelType + case object SmyteSpamTweet extends TweetSafetyLabelType + case object SemanticCoreMisinformation extends TweetSafetyLabelType + case object HighPReportedTweetScore extends TweetSafetyLabelType + case object HighSpammyTweetContentScore extends TweetSafetyLabelType + case object GoreAndViolenceTopicHighRecall extends TweetSafetyLabelType + case object CopypastaSpam extends TweetSafetyLabelType + case object ExperimentalSensitiveIllegal2 extends TweetSafetyLabelType + case object DownrankSpamReply extends TweetSafetyLabelType + case object NsfwText extends TweetSafetyLabelType + case object HighlyReportedAndMidhighToxicityScore extends TweetSafetyLabelType + case object DynamicProductAd extends TweetSafetyLabelType + case object TombstoneDevelopmentOnly extends TweetSafetyLabelType + case object TweetContainsHatefulConductSlurHighSeverity extends TweetSafetyLabelType + case object TweetContainsHatefulConductSlurMediumSeverity extends TweetSafetyLabelType + case object TweetContainsHatefulConductSlurLowSeverity extends TweetSafetyLabelType + case object RitoActionedTweet extends TweetSafetyLabelType + case object ExperimentalNudge extends TweetSafetyLabelType + case object PNegMultimodalHighPrecision extends TweetSafetyLabelType + case object PNegMultimodalHighRecall extends TweetSafetyLabelType + case object BrandSafetyNsfaAggregate extends TweetSafetyLabelType + case object HighCryptospamScore extends TweetSafetyLabelType + case object IpiDevelopmentOnly extends TweetSafetyLabelType + case object BounceEdits extends TweetSafetyLabelType + case object UnsafeUrl extends TweetSafetyLabelType + case object InterstitialDevelopmentOnly extends TweetSafetyLabelType + case object EdiDevelopmentOnly extends TweetSafetyLabelType + case object NsfwTextHighPrecision extends TweetSafetyLabelType + case object HatefulConduct extends TweetSafetyLabelType + case object HatefulConductViolentThreat extends TweetSafetyLabelType + case object NsfaHighPrecision extends TweetSafetyLabelType + case object BrandSafetyExperimental1 extends TweetSafetyLabelType + case object BrandSafetyExperimental2 extends TweetSafetyLabelType + case object BrandSafetyExperimental3 extends TweetSafetyLabelType + case object BrandSafetyExperimental4 extends TweetSafetyLabelType + + case object FosnrHatefulConduct extends TweetSafetyLabelType + case object FosnrHatefulConductLowSeveritySlur extends TweetSafetyLabelType + + case object Deprecated extends TweetSafetyLabelType + case object Unknown extends TweetSafetyLabelType + + def fromThrift(safetyLabelType: s.SafetyLabelType): TweetSafetyLabelType = + thriftToModelMap.get(safetyLabelType) match { + case Some(tweetSafetyLabelType) => tweetSafetyLabelType + case _ => + safetyLabelType match { + case s.SafetyLabelType.EnumUnknownSafetyLabelType(DeprecatedEnumValue) => Deprecated + case _ => + Unknown + } + } + + def toThrift(safetyLabelType: TweetSafetyLabelType): s.SafetyLabelType = { + modelToThriftMap.getOrElse(safetyLabelType, UnknownThriftSafetyLabelType) + } +} + +object TweetSafetyLabel { + def fromThrift(safetyLabelValue: s.SafetyLabelValue): 
TweetSafetyLabel =
+    fromTuple(safetyLabelValue.labelType, safetyLabelValue.label)
+
+  def fromTuple(
+    safetyLabelType: s.SafetyLabelType,
+    safetyLabel: s.SafetyLabel
+  ): TweetSafetyLabel = {
+    TweetSafetyLabel(
+      labelType = TweetSafetyLabelType.fromThrift(safetyLabelType),
+      source = safetyLabel.source.flatMap(LabelSource.fromString),
+      safetyLabelSource = safetyLabel.safetyLabelSource,
+      applicableUsers = safetyLabel.applicableUsers
+        .map { perspectivalUsers =>
+          (perspectivalUsers map {
+            _.userId
+          }).toSet
+        }.getOrElse(Set.empty),
+      score = safetyLabel.score,
+      modelMetadata = safetyLabel.modelMetadata.flatMap(TweetModelMetadata.fromThrift)
+    )
+  }
+
+  def toThrift(tweetSafetyLabel: TweetSafetyLabel): s.SafetyLabelValue = {
+    s.SafetyLabelValue(
+      labelType = TweetSafetyLabelType.toThrift(tweetSafetyLabel.labelType),
+      label = s.SafetyLabel(
+        applicableUsers = if (tweetSafetyLabel.applicableUsers.nonEmpty) {
+          Some(tweetSafetyLabel.applicableUsers.toSeq.map {
+            s.PerspectivalUser(_)
+          })
+        } else {
+          None
+        },
+        source = tweetSafetyLabel.source.map(_.name),
+        score = tweetSafetyLabel.score,
+        modelMetadata = tweetSafetyLabel.modelMetadata.map(TweetModelMetadata.toThrift)
+      )
+    )
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/UnitOfDiversion.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/UnitOfDiversion.scala
new file mode 100644
index 000000000..5d7f1caba
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/UnitOfDiversion.scala
@@ -0,0 +1,16 @@
+package com.twitter.visibility.models
+
+trait UnitOfDiversion {
+
+  def apply: (String, Any)
+}
+
+object UnitOfDiversion {
+  case class ConversationId(conversationId: Long) extends UnitOfDiversion {
+    override def apply: (String, Any) = ("conversation_id", conversationId)
+  }
+
+  case class TweetId(tweetId: Long) extends UnitOfDiversion {
+    override def apply: (String, Any) = ("tweet_id", tweetId)
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/UserAge.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserAge.scala
new file mode 100644
index 000000000..b2cdbb785
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserAge.scala
@@ -0,0 +1,15 @@
+package com.twitter.visibility.models
+
+case class UserAge(ageInYears: Option[Int]) {
+  def hasAge: Boolean = ageInYears.isDefined
+
+  def isGte(ageToCompare: Int): Boolean =
+    ageInYears
+      .collectFirst {
+        case age if age >= ageToCompare => true
+      }.getOrElse(false)
+
+  def unapply(userAge: UserAge): Option[Int] = {
+    ageInYears
+  }
+}
diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/UserLabel.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserLabel.scala
new file mode 100644
index 000000000..738f73e9e
--- /dev/null
+++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserLabel.scala
@@ -0,0 +1,244 @@
+package com.twitter.visibility.models
+
+import com.twitter.gizmoduck.{thriftscala => t}
+import com.twitter.util.Time
+import com.twitter.visibility.util.NamingUtils
+
+sealed trait UserLabelValue extends SafetyLabelType {
+  lazy val name: String = NamingUtils.getFriendlyName(this)
+}
+
+case class UserLabel(
+  id: Long,
+  createdAt: Time,
+  createdBy: String,
+  labelValue: UserLabelValue,
+  source: Option[LabelSource] = None)
+
+object UserLabelValue extends SafetyLabelType {
+
+  private lazy val nameToValueMap: Map[String, UserLabelValue] =
+    List.map(l => l.name.toLowerCase -> l).toMap
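+
+  // Built lazily from `List` (declared at the bottom of this object), so the
+  // forward reference is safe; keys are lowercased friendly names, which makes
+  // the `fromName` lookup below case-insensitive.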
+ def fromName(name: String): Option[UserLabelValue] = nameToValueMap.get(name.toLowerCase) + + private val UnknownThriftUserLabelValue = + t.LabelValue.EnumUnknownLabelValue(UnknownEnumValue) + + private lazy val thriftToModelMap: Map[t.LabelValue, UserLabelValue] = Map( + t.LabelValue.Abusive -> Abusive, + t.LabelValue.AbusiveHighRecall -> AbusiveHighRecall, + t.LabelValue.AgathaSpamTopUser -> AgathaSpamTopUser, + t.LabelValue.BirdwatchDisabled -> BirdwatchDisabled, + t.LabelValue.BlinkBad -> BlinkBad, + t.LabelValue.BlinkQuestionable -> BlinkQuestionable, + t.LabelValue.BlinkWorst -> BlinkWorst, + t.LabelValue.Compromised -> Compromised, + t.LabelValue.DelayedRemediation -> DelayedRemediation, + t.LabelValue.DoNotCharge -> DoNotCharge, + t.LabelValue.DoNotAmplify -> DoNotAmplify, + t.LabelValue.DownrankSpamReply -> DownrankSpamReply, + t.LabelValue.DuplicateContent -> DuplicateContent, + t.LabelValue.EngagementSpammer -> EngagementSpammer, + t.LabelValue.EngagementSpammerHighRecall -> EngagementSpammerHighRecall, + t.LabelValue.ExperimentalPfmUser1 -> ExperimentalPfmUser1, + t.LabelValue.ExperimentalPfmUser2 -> ExperimentalPfmUser2, + t.LabelValue.ExperimentalPfmUser3 -> ExperimentalPfmUser3, + t.LabelValue.ExperimentalPfmUser4 -> ExperimentalPfmUser4, + t.LabelValue.ExperimentalSeh1 -> ExperimentalSeh1, + t.LabelValue.ExperimentalSeh2 -> ExperimentalSeh2, + t.LabelValue.ExperimentalSeh3 -> ExperimentalSeh3, + t.LabelValue.ExperimentalSehUser4 -> ExperimentalSehUser4, + t.LabelValue.ExperimentalSehUser5 -> ExperimentalSehUser5, + t.LabelValue.ExperimentalSensitiveIllegal1 -> ExperimentalSensitiveIllegal1, + t.LabelValue.ExperimentalSensitiveIllegal2 -> ExperimentalSensitiveIllegal2, + t.LabelValue.FakeSignupDeferredRemediation -> FakeSignupDeferredRemediation, + t.LabelValue.FakeSignupHoldback -> FakeSignupHoldback, + t.LabelValue.GoreAndViolenceHighPrecision -> GoreAndViolenceHighPrecision, + t.LabelValue.GoreAndViolenceReportedHeuristics -> GoreAndViolenceReportedHeuristics, + t.LabelValue.HealthExperimentation1 -> HealthExperimentation1, + t.LabelValue.HealthExperimentation2 -> HealthExperimentation2, + t.LabelValue.HighRiskVerification -> HighRiskVerification, + t.LabelValue.LikelyIvs -> LikelyIvs, + t.LabelValue.LiveLowQuality -> LiveLowQuality, + t.LabelValue.LowQuality -> LowQuality, + t.LabelValue.LowQualityHighRecall -> LowQualityHighRecall, + t.LabelValue.NotGraduated -> NotGraduated, + t.LabelValue.NotificationSpamHeuristics -> NotificationSpamHeuristics, + t.LabelValue.NsfwAvatarImage -> NsfwAvatarImage, + t.LabelValue.NsfwBannerImage -> NsfwBannerImage, + t.LabelValue.NsfwHighPrecision -> NsfwHighPrecision, + t.LabelValue.NsfwHighRecall -> NsfwHighRecall, + t.LabelValue.NsfwNearPerfect -> NsfwNearPerfect, + t.LabelValue.NsfwReportedHeuristics -> NsfwReportedHeuristics, + t.LabelValue.NsfwSensitive -> NsfwSensitive, + t.LabelValue.NsfwText -> NsfwText, + t.LabelValue.ReadOnly -> ReadOnly, + t.LabelValue.RecentAbuseStrike -> RecentAbuseStrike, + t.LabelValue.RecentMisinfoStrike -> RecentMisinfoStrike, + t.LabelValue.RecentProfileModification -> RecentProfileModification, + t.LabelValue.RecentSuspension -> RecentSuspension, + t.LabelValue.RecommendationsBlacklist -> RecommendationsBlacklist, + t.LabelValue.SearchBlacklist -> SearchBlacklist, + t.LabelValue.SoftReadOnly -> SoftReadOnly, + t.LabelValue.SpamHighRecall -> SpamHighRecall, + t.LabelValue.SpammyUserModelHighPrecision -> SpammyUserModelHighPrecision, + t.LabelValue.StateMediaAccount -> StateMediaAccount, + 
t.LabelValue.TsViolation -> TsViolation, + t.LabelValue.UnconfirmedEmailSignup -> UnconfirmedEmailSignup, + t.LabelValue.LegalOpsCase -> LegalOpsCase, + t.LabelValue.AutomationHighRecall -> Deprecated, + t.LabelValue.AutomationHighRecallHoldback -> Deprecated, + t.LabelValue.BouncerUserFiltered -> Deprecated, + t.LabelValue.DeprecatedListBannerPdna -> Deprecated, + t.LabelValue.DeprecatedMigration50 -> Deprecated, + t.LabelValue.DmSpammer -> Deprecated, + t.LabelValue.DuplicateContentHoldback -> Deprecated, + t.LabelValue.FakeAccountExperiment -> Deprecated, + t.LabelValue.FakeAccountReadonly -> Deprecated, + t.LabelValue.FakeAccountRecaptcha -> Deprecated, + t.LabelValue.FakeAccountSspc -> Deprecated, + t.LabelValue.FakeAccountVoiceReadonly -> Deprecated, + t.LabelValue.FakeEngagement -> Deprecated, + t.LabelValue.HasBeenSuspended -> Deprecated, + t.LabelValue.HighProfile -> Deprecated, + t.LabelValue.NotificationsSpike -> Deprecated, + t.LabelValue.NsfaProfileHighRecall -> Deprecated, + t.LabelValue.NsfwUserName -> Deprecated, + t.LabelValue.PotentiallyCompromised -> Deprecated, + t.LabelValue.ProfileAdsBlacklist -> Deprecated, + t.LabelValue.RatelimitDms -> Deprecated, + t.LabelValue.RatelimitFavorites -> Deprecated, + t.LabelValue.RatelimitFollows -> Deprecated, + t.LabelValue.RatelimitRetweets -> Deprecated, + t.LabelValue.RatelimitTweets -> Deprecated, + t.LabelValue.RecentCompromised -> Deprecated, + t.LabelValue.RevenueOnlyHsSignal -> Deprecated, + t.LabelValue.SearchBlacklistHoldback -> Deprecated, + t.LabelValue.SpamHighRecallHoldback -> Deprecated, + t.LabelValue.SpamRepeatOffender -> Deprecated, + t.LabelValue.SpammerExperiment -> Deprecated, + t.LabelValue.TrendBlacklist -> Deprecated, + t.LabelValue.VerifiedDeceptiveIdentity -> Deprecated, + t.LabelValue.BrandSafetyNsfaAggregate -> Deprecated, + t.LabelValue.Pcf -> Deprecated, + t.LabelValue.Reserved97 -> Deprecated, + t.LabelValue.Reserved98 -> Deprecated, + t.LabelValue.Reserved99 -> Deprecated, + t.LabelValue.Reserved100 -> Deprecated, + t.LabelValue.Reserved101 -> Deprecated, + t.LabelValue.Reserved102 -> Deprecated, + t.LabelValue.Reserved103 -> Deprecated, + t.LabelValue.Reserved104 -> Deprecated, + t.LabelValue.Reserved105 -> Deprecated, + t.LabelValue.Reserved106 -> Deprecated + ) + + private lazy val modelToThriftMap: Map[UserLabelValue, t.LabelValue] = + (for ((k, v) <- thriftToModelMap) yield (v, k)) ++ Map( + Deprecated -> t.LabelValue.EnumUnknownLabelValue(DeprecatedEnumValue), + ) + + case object Abusive extends UserLabelValue + case object AbusiveHighRecall extends UserLabelValue + case object AgathaSpamTopUser extends UserLabelValue + case object BirdwatchDisabled extends UserLabelValue + case object BlinkBad extends UserLabelValue + case object BlinkQuestionable extends UserLabelValue + case object BlinkWorst extends UserLabelValue + case object Compromised extends UserLabelValue + case object DelayedRemediation extends UserLabelValue + case object DoNotAmplify extends UserLabelValue + case object DoNotCharge extends UserLabelValue + case object DownrankSpamReply extends UserLabelValue + case object DuplicateContent extends UserLabelValue + case object EngagementSpammer extends UserLabelValue + case object EngagementSpammerHighRecall extends UserLabelValue + case object ExperimentalPfmUser1 extends UserLabelValue + case object ExperimentalPfmUser2 extends UserLabelValue + case object ExperimentalPfmUser3 extends UserLabelValue + case object ExperimentalPfmUser4 extends UserLabelValue + case object 
ExperimentalSeh1 extends UserLabelValue
+  case object ExperimentalSeh2 extends UserLabelValue
+  case object ExperimentalSeh3 extends UserLabelValue
+  case object ExperimentalSehUser4 extends UserLabelValue
+  case object ExperimentalSehUser5 extends UserLabelValue
+  case object ExperimentalSensitiveIllegal1 extends UserLabelValue
+  case object ExperimentalSensitiveIllegal2 extends UserLabelValue
+  case object FakeSignupDeferredRemediation extends UserLabelValue
+  case object FakeSignupHoldback extends UserLabelValue
+  case object GoreAndViolenceHighPrecision extends UserLabelValue
+  case object GoreAndViolenceReportedHeuristics extends UserLabelValue
+  case object HealthExperimentation1 extends UserLabelValue
+  case object HealthExperimentation2 extends UserLabelValue
+  case object HighRiskVerification extends UserLabelValue
+  case object LegalOpsCase extends UserLabelValue
+  case object LikelyIvs extends UserLabelValue
+  case object LiveLowQuality extends UserLabelValue
+  case object LowQuality extends UserLabelValue
+  case object LowQualityHighRecall extends UserLabelValue
+  case object NotificationSpamHeuristics extends UserLabelValue
+  case object NotGraduated extends UserLabelValue
+  case object NsfwAvatarImage extends UserLabelValue
+  case object NsfwBannerImage extends UserLabelValue
+  case object NsfwHighPrecision extends UserLabelValue
+  case object NsfwHighRecall extends UserLabelValue
+  case object NsfwNearPerfect extends UserLabelValue
+  case object NsfwReportedHeuristics extends UserLabelValue
+  case object NsfwSensitive extends UserLabelValue
+  case object NsfwText extends UserLabelValue
+  case object ReadOnly extends UserLabelValue
+  case object RecentAbuseStrike extends UserLabelValue
+  case object RecentProfileModification extends UserLabelValue
+  case object RecentMisinfoStrike extends UserLabelValue
+  case object RecentSuspension extends UserLabelValue
+  case object RecommendationsBlacklist extends UserLabelValue
+  case object SearchBlacklist extends UserLabelValue
+  case object SoftReadOnly extends UserLabelValue
+  case object SpamHighRecall extends UserLabelValue
+  case object SpammyUserModelHighPrecision extends UserLabelValue
+  case object StateMediaAccount extends UserLabelValue
+  case object TsViolation extends UserLabelValue
+  case object UnconfirmedEmailSignup extends UserLabelValue
+
+  case object Deprecated extends UserLabelValue
+  case object Unknown extends UserLabelValue
+
+  def fromThrift(userLabelValue: t.LabelValue): UserLabelValue = {
+    thriftToModelMap.get(userLabelValue) match {
+      case Some(safetyLabelType) => safetyLabelType
+      case _ =>
+        userLabelValue match {
+          case t.LabelValue.EnumUnknownLabelValue(DeprecatedEnumValue) => Deprecated
+          case _ =>
+            Unknown
+        }
+    }
+  }
+
+  def toThrift(userLabelValue: UserLabelValue): t.LabelValue =
+    modelToThriftMap.getOrElse(userLabelValue, UnknownThriftUserLabelValue)
+
+  val List: List[UserLabelValue] = t.LabelValue.list.map(fromThrift)
+}
+
+object UserLabel {
+  def fromThrift(userLabel: t.Label): UserLabel = {
+    UserLabel(
+      userLabel.id,
+      Time.fromMilliseconds(userLabel.createdAtMsec),
+      userLabel.byUser,
+      UserLabelValue.fromThrift(userLabel.labelValue),
+      userLabel.source.flatMap(LabelSource.fromString)
+    )
+  }
+
+  def toThrift(userLabel: UserLabel): t.Label = {
+    t.Label(
+      userLabel.id,
+      UserLabelValue.toThrift(userLabel.labelValue),
+      userLabel.createdAt.inMillis,
+      byUser = userLabel.createdBy,
+      source = userLabel.source.map(_.name)
+    )
+  }
+}
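+
+// Illustrative round trip (a sketch; `thriftLabel` stands in for a gizmoduck
+// t.Label obtained elsewhere, it is not defined in this file):
+//   val model  = UserLabel.fromThrift(thriftLabel)
+//   val thrift = UserLabel.toThrift(model)
+// The trip is lossy for deprecated labels: UserLabelValue.toThrift(Deprecated)
+// returns t.LabelValue.EnumUnknownLabelValue(DeprecatedEnumValue) rather than
+// whichever specific deprecated thrift value was originally read.
diff --git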
a/visibilitylib/src/main/scala/com/twitter/visibility/models/UserSensitiveMediaSettings.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserSensitiveMediaSettings.scala new file mode 100644 index 000000000..85981a007 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserSensitiveMediaSettings.scala @@ -0,0 +1,13 @@ +package com.twitter.visibility.models + +import com.twitter.contenthealth.sensitivemediasettings.thriftscala.SensitiveMediaSettings + + +case class UserSensitiveMediaSettings(sensitiveMediaSettings: Option[SensitiveMediaSettings]) { + + def unapply( + userSensitiveMediaSettings: UserSensitiveMediaSettings + ): Option[SensitiveMediaSettings] = { + sensitiveMediaSettings + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/UserUnavailableStateEnum.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserUnavailableStateEnum.scala new file mode 100644 index 000000000..492ffac32 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/UserUnavailableStateEnum.scala @@ -0,0 +1,22 @@ +package com.twitter.visibility.models + +import com.twitter.visibility.thriftscala.UserVisibilityResult +import com.twitter.visibility.util.NamingUtils + +sealed trait UserUnavailableStateEnum { + lazy val name: String = NamingUtils.getFriendlyName(this) +} +object UserUnavailableStateEnum { + case object Deleted extends UserUnavailableStateEnum + case object BounceDeleted extends UserUnavailableStateEnum + case object Deactivated extends UserUnavailableStateEnum + case object Offboarded extends UserUnavailableStateEnum + case object Erased extends UserUnavailableStateEnum + case object Suspended extends UserUnavailableStateEnum + case object Protected extends UserUnavailableStateEnum + case object AuthorBlocksViewer extends UserUnavailableStateEnum + case object ViewerBlocksAuthor extends UserUnavailableStateEnum + case object ViewerMutesAuthor extends UserUnavailableStateEnum + case class Filtered(result: UserVisibilityResult) extends UserUnavailableStateEnum + case object Unavailable extends UserUnavailableStateEnum +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/ViewerContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/ViewerContext.scala new file mode 100644 index 000000000..4da5210c9 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/ViewerContext.scala @@ -0,0 +1,53 @@ +package com.twitter.visibility.models + +import com.twitter.context.TwitterContext +import com.twitter.context.thriftscala.Viewer +import com.twitter.featureswitches.{UserAgent => FSUserAgent} +import com.twitter.finatra.request.util.AddressUtils + +case class ViewerContext( + userId: Option[Long] = None, + guestId: Option[Long] = None, + userAgentStr: Option[String] = None, + clientApplicationId: Option[Long] = None, + auditIp: String = "0.0.0.0", + requestCountryCode: Option[String] = None, + requestLanguageCode: Option[String] = None, + deviceId: Option[String] = None, + ipTags: Set[String] = Set.empty, + isVerifiedCrawler: Boolean = false, + userRoles: Option[Set[String]] = None) { + val fsUserAgent: Option[FSUserAgent] = userAgentStr.flatMap(ua => FSUserAgent(userAgent = ua)) + + val isTwOffice: Boolean = ipTags.contains(AddressUtils.TwofficeIpTag) +} + +object ViewerContext { + def fromContext: ViewerContext = viewerContext.getOrElse(ViewerContext()) + + def fromContextWithViewerIdFallback(viewerId: Option[Long]): ViewerContext = + 
viewerContext + .map { viewer => + if (viewer.userId.isEmpty) { + viewer.copy(userId = viewerId) + } else { + viewer + } + }.getOrElse(ViewerContext(viewerId)) + + private def viewerContext: Option[ViewerContext] = + TwitterContext(TwitterContextPermit)().map(apply) + + def apply(viewer: Viewer): ViewerContext = new ViewerContext( + userId = viewer.userId, + guestId = viewer.guestId, + userAgentStr = viewer.userAgent, + clientApplicationId = viewer.clientApplicationId, + auditIp = viewer.auditIp.getOrElse("0.0.0.0"), + requestCountryCode = viewer.requestCountryCode collect { case value => value.toLowerCase }, + requestLanguageCode = viewer.requestLanguageCode collect { case value => value.toLowerCase }, + deviceId = viewer.deviceId, + ipTags = viewer.ipTags.toSet, + isVerifiedCrawler = viewer.isVerifiedCrawler.getOrElse(false) + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/ViolationLevel.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/ViolationLevel.scala new file mode 100644 index 000000000..f42faf6c7 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/ViolationLevel.scala @@ -0,0 +1,51 @@ +package com.twitter.visibility.models + +sealed trait ViolationLevel extends Product with Serializable { + val level: Int +} + +object ViolationLevel { + + case object DefaultLevel extends ViolationLevel { + override val level: Int = 0 + } + + case object Level1 extends ViolationLevel { + override val level: Int = 1 + } + + case object Level2 extends ViolationLevel { + override val level: Int = 2 + } + + case object Level3 extends ViolationLevel { + override val level: Int = 3 + } + + case object Level4 extends ViolationLevel { + override val level: Int = 4 + } + + private val safetyLabelToViolationLevel: Map[TweetSafetyLabelType, ViolationLevel] = Map( + TweetSafetyLabelType.FosnrHatefulConduct -> Level3, + TweetSafetyLabelType.FosnrHatefulConductLowSeveritySlur -> Level1, + ) + + val violationLevelToSafetyLabels: Map[ViolationLevel, Set[TweetSafetyLabelType]] = + safetyLabelToViolationLevel.groupBy { case (_, violationLevel) => violationLevel }.map { + case (violationLevel, collection) => (violationLevel, collection.keySet) + } + + def fromTweetSafetyLabel( + tweetSafetyLabel: TweetSafetyLabel + ): ViolationLevel = { + safetyLabelToViolationLevel.getOrElse(tweetSafetyLabel.labelType, DefaultLevel) + } + + def fromTweetSafetyLabelOpt( + tweetSafetyLabel: TweetSafetyLabel + ): Option[ViolationLevel] = { + safetyLabelToViolationLevel.get(tweetSafetyLabel.labelType) + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/models/package.scala b/visibilitylib/src/main/scala/com/twitter/visibility/models/package.scala new file mode 100644 index 000000000..8eb72126d --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/models/package.scala @@ -0,0 +1,5 @@ +package com.twitter.visibility + +package object models { + type CommunityId = Long +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/Action.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Action.scala new file mode 100644 index 000000000..47446832a --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Action.scala @@ -0,0 +1,916 @@ +package com.twitter.visibility.rules + +import com.twitter.datatools.entityservice.entities.thriftscala.FleetInterstitial +import com.twitter.scrooge.ThriftStruct +import com.twitter.visibility.common.actions.LocalizedMessage +import 
com.twitter.visibility.common.actions._ +import com.twitter.visibility.common.actions.converter.scala.AppealableReasonConverter +import com.twitter.visibility.common.actions.converter.scala.AvoidReasonConverter +import com.twitter.visibility.common.actions.converter.scala.ComplianceTweetNoticeEventTypeConverter +import com.twitter.visibility.common.actions.converter.scala.DownrankHomeTimelineReasonConverter +import com.twitter.visibility.common.actions.converter.scala.DropReasonConverter +import com.twitter.visibility.common.actions.converter.scala.InterstitialReasonConverter +import com.twitter.visibility.common.actions.converter.scala.LimitedActionsPolicyConverter +import com.twitter.visibility.common.actions.converter.scala.LimitedEngagementReasonConverter +import com.twitter.visibility.common.actions.converter.scala.LocalizedMessageConverter +import com.twitter.visibility.common.actions.converter.scala.SoftInterventionDisplayTypeConverter +import com.twitter.visibility.common.actions.converter.scala.SoftInterventionReasonConverter +import com.twitter.visibility.common.actions.converter.scala.TombstoneReasonConverter +import com.twitter.visibility.features.Feature +import com.twitter.visibility.logging.thriftscala.HealthActionType +import com.twitter.visibility.models.ViolationLevel +import com.twitter.visibility.strato.thriftscala.NudgeActionType.EnumUnknownNudgeActionType +import com.twitter.visibility.strato.thriftscala.{Nudge => StratoNudge} +import com.twitter.visibility.strato.thriftscala.{NudgeAction => StratoNudgeAction} +import com.twitter.visibility.strato.thriftscala.{NudgeActionType => StratoNudgeActionType} +import com.twitter.visibility.strato.thriftscala.{NudgeActionPayload => StratoNudgeActionPayload} +import com.twitter.visibility.thriftscala +import com.twitter.visibility.util.NamingUtils + +sealed trait Action { + lazy val name: String = NamingUtils.getFriendlyName(this) + lazy val fullName: String = NamingUtils.getFriendlyName(this) + + val severity: Int + def toActionThrift(): thriftscala.Action + + def isComposable: Boolean = false + + def toHealthActionTypeThrift: Option[HealthActionType] +} + +sealed trait Reason { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +sealed abstract class ActionWithReason(reason: Reason) extends Action { + override lazy val fullName: String = s"${this.name}/${reason.name}" +} + +object Reason { + + case object Bounce extends Reason + + case object ViewerReportedAuthor extends Reason + case object ViewerReportedTweet extends Reason + + case object DeactivatedAuthor extends Reason + case object OffboardedAuthor extends Reason + case object ErasedAuthor extends Reason + case object ProtectedAuthor extends Reason + case object SuspendedAuthor extends Reason + case object ViewerIsUnmentioned extends Reason + + case object Nsfw extends Reason + case object NsfwMedia extends Reason + case object NsfwViewerIsUnderage extends Reason + case object NsfwViewerHasNoStatedAge extends Reason + case object NsfwLoggedOut extends Reason + case object PossiblyUndesirable extends Reason + + case object AbuseEpisodic extends Reason + case object AbuseEpisodicEncourageSelfHarm extends Reason + case object AbuseEpisodicHatefulConduct extends Reason + case object AbuseGlorificationOfViolence extends Reason + case object AbuseGratuitousGore extends Reason + case object AbuseMobHarassment extends Reason + case object AbuseMomentOfDeathOrDeceasedUser extends Reason + case object AbusePrivateInformation extends Reason + case object 
AbuseRightToPrivacy extends Reason
+  case object AbuseThreatToExpose extends Reason
+  case object AbuseViolentSexualConduct extends Reason
+  case object AbuseViolentThreatHatefulConduct extends Reason
+  case object AbuseViolentThreatOrBounty extends Reason
+
+  case object AuthorBlocksViewer extends Reason
+  case object ViewerBlocksAuthor extends Reason
+  case object ViewerMutesAuthor extends Reason
+  case object ViewerHardMutedAuthor extends Reason
+
+  case object MutedKeyword extends Reason
+  case object Unspecified extends Reason
+
+  case object UntrustedUrl extends Reason
+
+  case object SpamReplyDownRank extends Reason
+
+  case object LowQualityTweet extends Reason
+
+  case object LowQualityMention extends Reason
+
+  case object SpamHighRecallTweet extends Reason
+
+  case object TweetLabelDuplicateContent extends Reason
+
+  case object TweetLabelDuplicateMention extends Reason
+
+  case object PdnaTweet extends Reason
+
+  case object TweetLabeledSpam extends Reason
+
+  case object OneOff extends Reason
+  case object VotingMisinformation extends Reason
+  case object HackedMaterials extends Reason
+  case object Scams extends Reason
+  case object PlatformManipulation extends Reason
+
+  case object FirstPageSearchResult extends Reason
+
+  case object MisinfoCivic extends Reason
+  case object MisinfoCrisis extends Reason
+  case object MisinfoGeneric extends Reason
+  case object MisinfoMedical extends Reason
+  case object Misleading extends Reason
+  case object ExclusiveTweet extends Reason
+  case object CommunityNotAMember extends Reason
+  case object CommunityTweetHidden extends Reason
+  case object CommunityTweetCommunityIsSuspended extends Reason
+  case object CommunityTweetAuthorRemoved extends Reason
+  case object InternalPromotedContent extends Reason
+  case object TrustedFriendsTweet extends Reason
+  case object Toxicity extends Reason
+  case object StaleTweet extends Reason
+  case object DmcaWithheld extends Reason
+  case object LegalDemandsWithheld extends Reason
+  case object LocalLawsWithheld extends Reason
+  case object HatefulConduct extends Reason
+  case object AbusiveBehavior extends Reason
+
+  case object NotSupportedOnDevice extends Reason
+
+  case object IpiDevelopmentOnly extends Reason
+  case object InterstitialDevelopmentOnly extends Reason
+
+  case class FosnrReason(appealableReason: AppealableReason) extends Reason
+
+  def toDropReason(reason: Reason): Option[DropReason] =
+    reason match {
+      case AuthorBlocksViewer => Some(DropReason.AuthorBlocksViewer)
+      case CommunityTweetHidden => Some(DropReason.CommunityTweetHidden)
+      case CommunityTweetCommunityIsSuspended => Some(DropReason.CommunityTweetCommunityIsSuspended)
+      case DmcaWithheld => Some(DropReason.DmcaWithheld)
+      case ExclusiveTweet => Some(DropReason.ExclusiveTweet)
+      case InternalPromotedContent => Some(DropReason.InternalPromotedContent)
+      case LegalDemandsWithheld => Some(DropReason.LegalDemandsWithheld)
+      case LocalLawsWithheld => Some(DropReason.LocalLawsWithheld)
+      case Nsfw => Some(DropReason.NsfwAuthor)
+      case NsfwLoggedOut => Some(DropReason.NsfwLoggedOut)
+      case NsfwViewerHasNoStatedAge => Some(DropReason.NsfwViewerHasNoStatedAge)
+      case NsfwViewerIsUnderage => Some(DropReason.NsfwViewerIsUnderage)
+      case ProtectedAuthor => Some(DropReason.ProtectedAuthor)
+      case StaleTweet => Some(DropReason.StaleTweet)
+      case SuspendedAuthor => Some(DropReason.SuspendedAuthor)
+      case Unspecified => Some(DropReason.Unspecified)
+      case ViewerBlocksAuthor => Some(DropReason.ViewerBlocksAuthor)
+      case ViewerHardMutedAuthor => Some(DropReason.ViewerMutesAuthor)
+      case ViewerMutesAuthor => Some(DropReason.ViewerMutesAuthor)
+      case TrustedFriendsTweet => Some(DropReason.TrustedFriendsTweet)
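+      // Reasons with no dedicated DropReason fall back to Unspecified rather
+      // than None, so toDropReason always returns a concrete value.
+      case _ =>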
Some(DropReason.Unspecified) + } + + def fromDropReason(dropReason: DropReason): Reason = + dropReason match { + case DropReason.AuthorBlocksViewer => AuthorBlocksViewer + case DropReason.CommunityTweetHidden => CommunityTweetHidden + case DropReason.CommunityTweetCommunityIsSuspended => CommunityTweetCommunityIsSuspended + case DropReason.DmcaWithheld => DmcaWithheld + case DropReason.ExclusiveTweet => ExclusiveTweet + case DropReason.InternalPromotedContent => InternalPromotedContent + case DropReason.LegalDemandsWithheld => LegalDemandsWithheld + case DropReason.LocalLawsWithheld => LocalLawsWithheld + case DropReason.NsfwAuthor => Nsfw + case DropReason.NsfwLoggedOut => NsfwLoggedOut + case DropReason.NsfwViewerHasNoStatedAge => NsfwViewerHasNoStatedAge + case DropReason.NsfwViewerIsUnderage => NsfwViewerIsUnderage + case DropReason.ProtectedAuthor => ProtectedAuthor + case DropReason.StaleTweet => StaleTweet + case DropReason.SuspendedAuthor => SuspendedAuthor + case DropReason.ViewerBlocksAuthor => ViewerBlocksAuthor + case DropReason.ViewerMutesAuthor => ViewerMutesAuthor + case DropReason.TrustedFriendsTweet => TrustedFriendsTweet + case DropReason.Unspecified => Unspecified + } + + def toAppealableReason(reason: Reason, violationLevel: ViolationLevel): Option[AppealableReason] = + reason match { + case HatefulConduct => Some(AppealableReason.HatefulConduct(violationLevel.level)) + case AbusiveBehavior => Some(AppealableReason.AbusiveBehavior(violationLevel.level)) + case _ => Some(AppealableReason.Unspecified(violationLevel.level)) + } + + def fromAppealableReason(appealableReason: AppealableReason): Reason = + appealableReason match { + case AppealableReason.HatefulConduct(level) => HatefulConduct + case AppealableReason.AbusiveBehavior(level) => AbusiveBehavior + case AppealableReason.Unspecified(level) => Unspecified + } + + def toSoftInterventionReason(appealableReason: AppealableReason): SoftInterventionReason = + appealableReason match { + case AppealableReason.HatefulConduct(level) => + SoftInterventionReason.FosnrReason(appealableReason) + case AppealableReason.AbusiveBehavior(level) => + SoftInterventionReason.FosnrReason(appealableReason) + case AppealableReason.Unspecified(level) => + SoftInterventionReason.FosnrReason(appealableReason) + } + + def toLimitedEngagementReason(appealableReason: AppealableReason): LimitedEngagementReason = + appealableReason match { + case AppealableReason.HatefulConduct(level) => + LimitedEngagementReason.FosnrReason(appealableReason) + case AppealableReason.AbusiveBehavior(level) => + LimitedEngagementReason.FosnrReason(appealableReason) + case AppealableReason.Unspecified(level) => + LimitedEngagementReason.FosnrReason(appealableReason) + } + + val NSFW_MEDIA: Set[Reason] = Set(Nsfw, NsfwMedia) + + def toInterstitialReason(reason: Reason): Option[InterstitialReason] = + reason match { + case r if NSFW_MEDIA.contains(r) => Some(InterstitialReason.ContainsNsfwMedia) + case PossiblyUndesirable => Some(InterstitialReason.PossiblyUndesirable) + case MutedKeyword => Some(InterstitialReason.MatchesMutedKeyword("")) + case ViewerReportedAuthor => Some(InterstitialReason.ViewerReportedAuthor) + case ViewerReportedTweet => Some(InterstitialReason.ViewerReportedTweet) + case ViewerBlocksAuthor => Some(InterstitialReason.ViewerBlocksAuthor) + case ViewerMutesAuthor => Some(InterstitialReason.ViewerMutesAuthor) + case ViewerHardMutedAuthor => Some(InterstitialReason.ViewerMutesAuthor) + case InterstitialDevelopmentOnly => 
Some(InterstitialReason.DevelopmentOnly) + case DmcaWithheld => Some(InterstitialReason.DmcaWithheld) + case LegalDemandsWithheld => Some(InterstitialReason.LegalDemandsWithheld) + case LocalLawsWithheld => Some(InterstitialReason.LocalLawsWithheld) + case HatefulConduct => Some(InterstitialReason.HatefulConduct) + case AbusiveBehavior => Some(InterstitialReason.AbusiveBehavior) + case FosnrReason(appealableReason) => Some(InterstitialReason.FosnrReason(appealableReason)) + case _ => None + } + + def fromInterstitialReason(interstitialReason: InterstitialReason): Reason = + interstitialReason match { + case InterstitialReason.ContainsNsfwMedia => Reason.NsfwMedia + case InterstitialReason.PossiblyUndesirable => Reason.PossiblyUndesirable + case InterstitialReason.MatchesMutedKeyword(_) => Reason.MutedKeyword + case InterstitialReason.ViewerReportedAuthor => Reason.ViewerReportedAuthor + case InterstitialReason.ViewerReportedTweet => Reason.ViewerReportedTweet + case InterstitialReason.ViewerBlocksAuthor => Reason.ViewerBlocksAuthor + case InterstitialReason.ViewerMutesAuthor => Reason.ViewerMutesAuthor + case InterstitialReason.DevelopmentOnly => Reason.InterstitialDevelopmentOnly + case InterstitialReason.DmcaWithheld => Reason.DmcaWithheld + case InterstitialReason.LegalDemandsWithheld => Reason.LegalDemandsWithheld + case InterstitialReason.LocalLawsWithheld => Reason.LocalLawsWithheld + case InterstitialReason.HatefulConduct => Reason.HatefulConduct + case InterstitialReason.AbusiveBehavior => Reason.AbusiveBehavior + case InterstitialReason.FosnrReason(reason) => Reason.fromAppealableReason(reason) + } + +} + +sealed trait Epitaph { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +object Epitaph { + + case object Unavailable extends Epitaph + + case object Blocked extends Epitaph + case object BlockedBy extends Epitaph + case object Reported extends Epitaph + + case object BounceDeleted extends Epitaph + case object Deleted extends Epitaph + case object NotFound extends Epitaph + case object PublicInterest extends Epitaph + + case object Bounced extends Epitaph + case object Protected extends Epitaph + case object Suspended extends Epitaph + case object Offboarded extends Epitaph + case object Deactivated extends Epitaph + + case object MutedKeyword extends Epitaph + case object Underage extends Epitaph + case object NoStatedAge extends Epitaph + case object LoggedOutAge extends Epitaph + case object SuperFollowsContent extends Epitaph + + case object Moderated extends Epitaph + case object ForEmergencyUseOnly extends Epitaph + case object UnavailableWithoutLink extends Epitaph + case object CommunityTweetHidden extends Epitaph + case object CommunityTweetMemberRemoved extends Epitaph + case object CommunityTweetCommunityIsSuspended extends Epitaph + + case object UserSuspended extends Epitaph + + case object DevelopmentOnly extends Epitaph + + case object AdultMedia extends Epitaph + case object ViolentMedia extends Epitaph + case object OtherSensitiveMedia extends Epitaph + + case object DmcaWithheldMedia extends Epitaph + case object LegalDemandsWithheldMedia extends Epitaph + case object LocalLawsWithheldMedia extends Epitaph + + case object ToxicReplyFiltered extends Epitaph +} + +sealed trait IsInterstitial { + def toInterstitialThriftWrapper(): thriftscala.AnyInterstitial + def toInterstitialThrift(): ThriftStruct +} + +sealed trait IsAppealable { + def toAppealableThrift(): thriftscala.Appealable +} + +sealed trait IsLimitedEngagements { + def policy: 
Option[LimitedActionsPolicy] + def getLimitedEngagementReason: LimitedEngagementReason +} + +object IsLimitedEngagements { + def unapply( + ile: IsLimitedEngagements + ): Option[(Option[LimitedActionsPolicy], LimitedEngagementReason)] = { + Some((ile.policy, ile.getLimitedEngagementReason)) + } +} + +sealed abstract class ActionWithEpitaph(epitaph: Epitaph) extends Action { + override lazy val fullName: String = s"${this.name}/${epitaph.name}" +} + +case class Appealable( + reason: Reason, + violationLevel: ViolationLevel, + localizedMessage: Option[LocalizedMessage] = None) + extends ActionWithReason(reason) + with IsAppealable { + + override val severity: Int = 17 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Appealable(toAppealableThrift()) + + override def toAppealableThrift(): thriftscala.Appealable = + thriftscala.Appealable( + Reason.toAppealableReason(reason, violationLevel).map(AppealableReasonConverter.toThrift), + localizedMessage.map(LocalizedMessageConverter.toThrift) + ) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.Appealable) +} + +case class Drop(reason: Reason, applicableCountries: Option[Seq[String]] = None) + extends ActionWithReason(reason) { + + override val severity: Int = 16 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Drop( + thriftscala.Drop( + Reason.toDropReason(reason).map(DropReasonConverter.toThrift), + applicableCountries + )) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Drop) +} + +case class Interstitial( + reason: Reason, + localizedMessage: Option[LocalizedMessage] = None, + applicableCountries: Option[Seq[String]] = None) + extends ActionWithReason(reason) + with IsInterstitial { + + override val severity: Int = 10 + override def toInterstitialThriftWrapper(): thriftscala.AnyInterstitial = + thriftscala.AnyInterstitial.Interstitial( + toInterstitialThrift() + ) + + override def toInterstitialThrift(): thriftscala.Interstitial = + thriftscala.Interstitial( + Reason.toInterstitialReason(reason).map(InterstitialReasonConverter.toThrift), + localizedMessage.map(LocalizedMessageConverter.toThrift) + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Interstitial(toInterstitialThrift()) + + def toMediaActionThrift(): thriftscala.MediaAction = + thriftscala.MediaAction.Interstitial(toInterstitialThrift()) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.TweetInterstitial) +} + +case class InterstitialLimitedEngagements( + reason: Reason, + limitedEngagementReason: Option[LimitedEngagementReason], + localizedMessage: Option[LocalizedMessage] = None, + policy: Option[LimitedActionsPolicy] = None) + extends ActionWithReason(reason) + with IsInterstitial + with IsLimitedEngagements { + + override val severity: Int = 11 + override def toInterstitialThriftWrapper(): thriftscala.AnyInterstitial = + thriftscala.AnyInterstitial.InterstitialLimitedEngagements( + toInterstitialThrift() + ) + + override def toInterstitialThrift(): thriftscala.InterstitialLimitedEngagements = + thriftscala.InterstitialLimitedEngagements( + limitedEngagementReason.map(LimitedEngagementReasonConverter.toThrift), + localizedMessage.map(LocalizedMessageConverter.toThrift) + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.InterstitialLimitedEngagements(toInterstitialThrift()) + + override def 
isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.LimitedEngagements) + + def getLimitedEngagementReason: LimitedEngagementReason = limitedEngagementReason.getOrElse( + LimitedEngagementReason.NonCompliant + ) +} + +case object Allow extends Action { + + override val severity: Int = -1 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Allow(thriftscala.Allow()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = None +} + +case object NotEvaluated extends Action { + + override val severity: Int = -1 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.NotEvaluated(thriftscala.NotEvaluated()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = None +} + +case class Tombstone(epitaph: Epitaph, applicableCountryCodes: Option[Seq[String]] = None) + extends ActionWithEpitaph(epitaph) { + + override val severity: Int = 15 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Tombstone(thriftscala.Tombstone()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Tombstone) +} + +case class LocalizedTombstone(reason: TombstoneReason, message: LocalizedMessage) extends Action { + override lazy val fullName: String = s"${this.name}/${NamingUtils.getFriendlyName(reason)}" + + override val severity: Int = 15 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Tombstone( + thriftscala.Tombstone( + reason = TombstoneReasonConverter.toThrift(Some(reason)), + message = Some(LocalizedMessageConverter.toThrift(message)) + )) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Tombstone) +} + +case class DownrankHomeTimeline(reason: Option[DownrankHomeTimelineReason]) extends Action { + + override val severity: Int = 9 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.DownrankHomeTimeline(toDownrankThrift()) + + def toDownrankThrift(): thriftscala.DownrankHomeTimeline = + thriftscala.DownrankHomeTimeline( + reason.map(DownrankHomeTimelineReasonConverter.toThrift) + ) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Downrank) +} + +case class Avoid(avoidReason: Option[AvoidReason] = None) extends Action { + + override val severity: Int = 1 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Avoid(toAvoidThrift()) + + def toAvoidThrift(): thriftscala.Avoid = + thriftscala.Avoid( + avoidReason.map(AvoidReasonConverter.toThrift) + ) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Avoid) +} + +case object Downrank extends Action { + + override val severity: Int = 0 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.Downrank(thriftscala.Downrank()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some(HealthActionType.Downrank) +} + +case object ConversationSectionLowQuality extends Action { + + override val severity: Int = 4 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.ConversationSectionLowQuality(thriftscala.ConversationSectionLowQuality()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.ConversationSectionLowQuality) +} + +case object ConversationSectionAbusiveQuality extends Action { + + 
override val severity: Int = 5 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.ConversationSectionAbusiveQuality( + thriftscala.ConversationSectionAbusiveQuality()) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.ConversationSectionAbusiveQuality) + + def toConversationSectionAbusiveQualityThrift(): thriftscala.ConversationSectionAbusiveQuality = + thriftscala.ConversationSectionAbusiveQuality() +} + +case class LimitedEngagements( + reason: LimitedEngagementReason, + policy: Option[LimitedActionsPolicy] = None) + extends Action + with IsLimitedEngagements { + + override val severity: Int = 6 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.LimitedEngagements(toLimitedEngagementsThrift()) + + def toLimitedEngagementsThrift(): thriftscala.LimitedEngagements = + thriftscala.LimitedEngagements( + Some(LimitedEngagementReasonConverter.toThrift(reason)), + policy.map(LimitedActionsPolicyConverter.toThrift), + Some(reason.toLimitedActionsString) + ) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.LimitedEngagements) + + def getLimitedEngagementReason: LimitedEngagementReason = reason +} + +case class EmergencyDynamicInterstitial( + copy: String, + linkOpt: Option[String], + localizedMessage: Option[LocalizedMessage] = None, + policy: Option[LimitedActionsPolicy] = None) + extends Action + with IsInterstitial + with IsLimitedEngagements { + + override val severity: Int = 11 + override def toInterstitialThriftWrapper(): thriftscala.AnyInterstitial = + thriftscala.AnyInterstitial.EmergencyDynamicInterstitial( + toInterstitialThrift() + ) + + override def toInterstitialThrift(): thriftscala.EmergencyDynamicInterstitial = + thriftscala.EmergencyDynamicInterstitial( + copy, + linkOpt, + localizedMessage.map(LocalizedMessageConverter.toThrift) + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.EmergencyDynamicInterstitial(toInterstitialThrift()) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.TweetInterstitial) + + def getLimitedEngagementReason: LimitedEngagementReason = LimitedEngagementReason.NonCompliant +} + +case class SoftIntervention( + reason: SoftInterventionReason, + engagementNudge: Boolean, + suppressAutoplay: Boolean, + warning: Option[String] = None, + detailsUrl: Option[String] = None, + displayType: Option[SoftInterventionDisplayType] = None, + fleetInterstitial: Option[FleetInterstitial] = None) + extends Action { + + override val severity: Int = 7 + def toSoftInterventionThrift(): thriftscala.SoftIntervention = + thriftscala.SoftIntervention( + Some(SoftInterventionReasonConverter.toThrift(reason)), + engagementNudge = Some(engagementNudge), + suppressAutoplay = Some(suppressAutoplay), + warning = warning, + detailsUrl = detailsUrl, + displayType = SoftInterventionDisplayTypeConverter.toThrift(displayType) + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.SoftIntervention(toSoftInterventionThrift()) + + override def isComposable: Boolean = true + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.SoftIntervention) +} + +case class TweetInterstitial( + interstitial: Option[IsInterstitial], + softIntervention: Option[SoftIntervention], + limitedEngagements: Option[LimitedEngagements], + downrank: 
Option[DownrankHomeTimeline], + avoid: Option[Avoid], + mediaInterstitial: Option[Interstitial] = None, + tweetVisibilityNudge: Option[TweetVisibilityNudge] = None, + abusiveQuality: Option[ConversationSectionAbusiveQuality.type] = None, + appealable: Option[Appealable] = None) + extends Action { + + override val severity: Int = 12 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.TweetInterstitial( + thriftscala.TweetInterstitial( + interstitial.map(_.toInterstitialThriftWrapper()), + softIntervention.map(_.toSoftInterventionThrift()), + limitedEngagements.map(_.toLimitedEngagementsThrift()), + downrank.map(_.toDownrankThrift()), + avoid.map(_.toAvoidThrift()), + mediaInterstitial.map(_.toMediaActionThrift()), + tweetVisibilityNudge.map(_.toTweetVisbilityNudgeThrift()), + abusiveQuality.map(_.toConversationSectionAbusiveQualityThrift()), + appealable.map(_.toAppealableThrift()) + ) + ) + + override def toHealthActionTypeThrift: Option[HealthActionType] = Some( + HealthActionType.TweetInterstitial) +} + +sealed trait LocalizedNudgeActionType +object LocalizedNudgeActionType { + case object Reply extends LocalizedNudgeActionType + case object Retweet extends LocalizedNudgeActionType + case object Like extends LocalizedNudgeActionType + case object Share extends LocalizedNudgeActionType + case object Unspecified extends LocalizedNudgeActionType + + def toThrift( + localizedNudgeActionType: LocalizedNudgeActionType + ): thriftscala.TweetVisibilityNudgeActionType = + localizedNudgeActionType match { + case Reply => thriftscala.TweetVisibilityNudgeActionType.Reply + case Retweet => thriftscala.TweetVisibilityNudgeActionType.Retweet + case Like => thriftscala.TweetVisibilityNudgeActionType.Like + case Share => thriftscala.TweetVisibilityNudgeActionType.Share + case Unspecified => + thriftscala.TweetVisibilityNudgeActionType.EnumUnknownTweetVisibilityNudgeActionType(5) + } + + def fromStratoThrift(stratoNudgeActionType: StratoNudgeActionType): LocalizedNudgeActionType = + stratoNudgeActionType match { + case StratoNudgeActionType.Reply => Reply + case StratoNudgeActionType.Retweet => Retweet + case StratoNudgeActionType.Like => Like + case StratoNudgeActionType.Share => Share + case EnumUnknownNudgeActionType(_) => Unspecified + } +} + +case class LocalizedNudgeActionPayload( + heading: Option[String], + subheading: Option[String], + iconName: Option[String], + ctaTitle: Option[String], + ctaUrl: Option[String], + postCtaText: Option[String]) { + + def toThrift(): thriftscala.TweetVisibilityNudgeActionPayload = { + thriftscala.TweetVisibilityNudgeActionPayload( + heading = heading, + subheading = subheading, + iconName = iconName, + ctaTitle = ctaTitle, + ctaUrl = ctaUrl, + postCtaText = postCtaText + ) + } +} + +object LocalizedNudgeActionPayload { + def fromStratoThrift( + stratoNudgeActionPayload: StratoNudgeActionPayload + ): LocalizedNudgeActionPayload = + LocalizedNudgeActionPayload( + heading = stratoNudgeActionPayload.heading, + subheading = stratoNudgeActionPayload.subheading, + iconName = stratoNudgeActionPayload.iconName, + ctaTitle = stratoNudgeActionPayload.ctaTitle, + ctaUrl = stratoNudgeActionPayload.ctaUrl, + postCtaText = stratoNudgeActionPayload.postCtaText + ) +} + +case class LocalizedNudgeAction( + nudgeActionType: LocalizedNudgeActionType, + nudgeActionPayload: Option[LocalizedNudgeActionPayload]) { + def toThrift(): thriftscala.TweetVisibilityNudgeAction = { + thriftscala.TweetVisibilityNudgeAction( + tweetVisibilitynudgeActionType = 
LocalizedNudgeActionType.toThrift(nudgeActionType), + tweetVisibilityNudgeActionPayload = nudgeActionPayload.map(_.toThrift) + ) + } +} + +object LocalizedNudgeAction { + def fromStratoThrift(stratoNudgeAction: StratoNudgeAction): LocalizedNudgeAction = + LocalizedNudgeAction( + nudgeActionType = + LocalizedNudgeActionType.fromStratoThrift(stratoNudgeAction.nudgeActionType), + nudgeActionPayload = + stratoNudgeAction.nudgeActionPayload.map(LocalizedNudgeActionPayload.fromStratoThrift) + ) +} + +case class LocalizedNudge(localizedNudgeActions: Seq[LocalizedNudgeAction]) + +case object LocalizedNudge { + def fromStratoThrift(stratoNudge: StratoNudge): LocalizedNudge = + LocalizedNudge(localizedNudgeActions = + stratoNudge.nudgeActions.map(LocalizedNudgeAction.fromStratoThrift)) +} + +case class TweetVisibilityNudge( + reason: TweetVisibilityNudgeReason, + localizedNudge: Option[LocalizedNudge] = None) + extends Action { + + override val severity: Int = 3 + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.TweetVisibilityNudge( + localizedNudge match { + case Some(nudge) => + thriftscala.TweetVisibilityNudge( + tweetVisibilityNudgeActions = Some(nudge.localizedNudgeActions.map(_.toThrift())) + ) + case _ => thriftscala.TweetVisibilityNudge(tweetVisibilityNudgeActions = None) + } + ) + + override def toHealthActionTypeThrift: Option[HealthActionType] = + Some(HealthActionType.TweetVisibilityNudge) + + def toTweetVisbilityNudgeThrift(): thriftscala.TweetVisibilityNudge = + thriftscala.TweetVisibilityNudge(tweetVisibilityNudgeActions = + localizedNudge.map(_.localizedNudgeActions.map(_.toThrift()))) +} + +trait BaseComplianceTweetNotice { + val complianceTweetNoticeEventType: ComplianceTweetNoticeEventType + val details: Option[String] + val extendedDetailsUrl: Option[String] +} + +case class ComplianceTweetNoticePreEnrichment( + reason: Reason, + complianceTweetNoticeEventType: ComplianceTweetNoticeEventType, + details: Option[String] = None, + extendedDetailsUrl: Option[String] = None) + extends Action + with BaseComplianceTweetNotice { + + override val severity: Int = 2 + def toComplianceTweetNoticeThrift(): thriftscala.ComplianceTweetNotice = + thriftscala.ComplianceTweetNotice( + ComplianceTweetNoticeEventTypeConverter.toThrift(complianceTweetNoticeEventType), + ComplianceTweetNoticeEventTypeConverter.eventTypeToLabelTitle(complianceTweetNoticeEventType), + details, + extendedDetailsUrl + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.ComplianceTweetNotice( + toComplianceTweetNoticeThrift() + ) + + override def toHealthActionTypeThrift: Option[HealthActionType] = None + + def toComplianceTweetNotice(): ComplianceTweetNotice = { + ComplianceTweetNotice( + complianceTweetNoticeEventType = complianceTweetNoticeEventType, + labelTitle = ComplianceTweetNoticeEventTypeConverter.eventTypeToLabelTitle( + complianceTweetNoticeEventType), + details = details, + extendedDetailsUrl = extendedDetailsUrl + ) + } +} + +case class ComplianceTweetNotice( + complianceTweetNoticeEventType: ComplianceTweetNoticeEventType, + labelTitle: Option[String] = None, + details: Option[String] = None, + extendedDetailsUrl: Option[String] = None) + extends Action + with BaseComplianceTweetNotice { + + override val severity: Int = 2 + def toComplianceTweetNoticeThrift(): thriftscala.ComplianceTweetNotice = + thriftscala.ComplianceTweetNotice( + ComplianceTweetNoticeEventTypeConverter.toThrift(complianceTweetNoticeEventType), + labelTitle, + details, + 
extendedDetailsUrl + ) + + override def toActionThrift(): thriftscala.Action = + thriftscala.Action.ComplianceTweetNotice( + toComplianceTweetNoticeThrift() + ) + + override def toHealthActionTypeThrift: Option[HealthActionType] = None +} + +object Action { + def toThrift[T <: Action](action: T): thriftscala.Action = + action.toActionThrift() + + def getFirstInterstitial(actions: Action*): Option[IsInterstitial] = + actions.collectFirst { + case ile: InterstitialLimitedEngagements => ile + case edi: EmergencyDynamicInterstitial => edi + case i: Interstitial => i + } + + def getFirstSoftIntervention(actions: Action*): Option[SoftIntervention] = + actions.collectFirst { + case si: SoftIntervention => si + } + + def getFirstLimitedEngagements(actions: Action*): Option[LimitedEngagements] = + actions.collectFirst { + case le: LimitedEngagements => le + } + + def getAllLimitedEngagements(actions: Action*): Seq[IsLimitedEngagements] = + actions.collect { + case ile: IsLimitedEngagements => ile + } + + def getFirstDownrankHomeTimeline(actions: Action*): Option[DownrankHomeTimeline] = + actions.collectFirst { + case dr: DownrankHomeTimeline => dr + } + + def getFirstAvoid(actions: Action*): Option[Avoid] = + actions.collectFirst { + case a: Avoid => a + } + + def getFirstMediaInterstitial(actions: Action*): Option[Interstitial] = + actions.collectFirst { + case i: Interstitial if Reason.NSFW_MEDIA.contains(i.reason) => i + } + + def getFirstTweetVisibilityNudge(actions: Action*): Option[TweetVisibilityNudge] = + actions.collectFirst { + case n: TweetVisibilityNudge => n + } +} + +sealed trait State { + lazy val name: String = NamingUtils.getFriendlyName(this) +} + +object State { + case object Pending extends State + case object Disabled extends State + final case class MissingFeature(features: Set[Feature[_]]) extends State + final case class FeatureFailed(features: Map[Feature[_], Throwable]) extends State + final case class RuleFailed(throwable: Throwable) extends State + case object Skipped extends State + case object ShortCircuited extends State + case object Heldback extends State + case object Evaluated extends State +} + +case class RuleResult(action: Action, state: State) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/AdvancedFilteringRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/AdvancedFilteringRules.scala new file mode 100644 index 000000000..f3e3da102 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/AdvancedFilteringRules.scala @@ -0,0 +1,71 @@ +package com.twitter.visibility.rules + +import com.twitter.gizmoduck.thriftscala.MentionFilter.Following +import com.twitter.visibility.features.ViewerMentionFilter +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Reason.Unspecified + +object NoConfirmedEmailRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + ViewerFiltersNoConfirmedEmail, + Not(AuthorHasConfirmedEmail) + ) + ) + +object NoConfirmedPhoneRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + ViewerFiltersNoConfirmedPhone, + Not(AuthorHasVerifiedPhone) + ) + ) + +object NoDefaultProfileImageRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + ViewerFiltersDefaultProfileImage, + AuthorHasDefaultProfileImage + ) + ) + +object NoNewUsersRule + extends RuleWithConstantAction( 
+ Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + AuthorIsNewAccount + ) + ) + +object NoNotFollowedByRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + ViewerFiltersNotFollowedBy, + Not(AuthorDoesFollowViewer) + ) + ) + +object OnlyPeopleIFollowRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + Equals(ViewerMentionFilter, Following), + Not(NotificationIsOnCommunityTweet) + ) + ) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/rules/BUILD new file mode 100644 index 000000000..a8c1c0ac1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/BUILD @@ -0,0 +1,37 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "3rdparty/jvm/com/squareup/okhttp:okhttp3", + "abdecider/src/main/scala", + "configapi/configapi-core", + "decider/src/main/scala", + "health-platform-manipulation/src/main/scala/com/twitter/health/platform_manipulation/stcm_tweet_holdback", + "scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers", + "servo/decider/src/main/scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/scala/com/twitter/takedown/util", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core/src/main/scala/com/twitter/stitch", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/CardRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/CardRules.scala new file mode 100644 index 000000000..695d40ad4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/CardRules.scala @@ -0,0 +1,52 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.FSRuleParams.CardUriRootDomainDenyListParam +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableCardUriRootDomainCardDenylistRule +import com.twitter.visibility.configapi.params.RuleParams.EnableCommunityNonMemberPollCardRule +import 
com.twitter.visibility.configapi.params.RuleParams.EnableCommunityNonMemberPollCardRuleFailClosed +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.CardUriHasRootDomain +import com.twitter.visibility.rules.Condition.CommunityTweetCommunityVisible +import com.twitter.visibility.rules.Condition.IsPollCard +import com.twitter.visibility.rules.Condition.LoggedOutOrViewerNotFollowingAuthor +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.Or +import com.twitter.visibility.rules.Condition.ProtectedAuthor +import com.twitter.visibility.rules.Condition.TweetIsCommunityTweet +import com.twitter.visibility.rules.Condition.ViewerIsCommunityMember + +object DropProtectedAuthorPollCardRule + extends RuleWithConstantAction( + Drop(Reason.ProtectedAuthor), + And( + IsPollCard, + ProtectedAuthor, + LoggedOutOrViewerNotFollowingAuthor, + ) + ) + +object DropCardUriRootDomainDenylistRule + extends RuleWithConstantAction( + Drop(Reason.Unspecified), + And(CardUriHasRootDomain(CardUriRootDomainDenyListParam)) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableCardUriRootDomainCardDenylistRule) +} + +object DropCommunityNonMemberPollCardRule + extends RuleWithConstantAction( + Drop(Reason.CommunityNotAMember), + And( + IsPollCard, + TweetIsCommunityTweet, + Or( + Not(ViewerIsCommunityMember), + Not(CommunityTweetCommunityVisible), + ) + ), + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableCommunityNonMemberPollCardRule) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq( + EnableCommunityNonMemberPollCardRuleFailClosed) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/ComposableActions.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ComposableActions.scala new file mode 100644 index 000000000..36b647618 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ComposableActions.scala @@ -0,0 +1,45 @@ +package com.twitter.visibility.rules + +object ComposableActions { + + object ComposableActionsWithConversationSectionAbusiveQuality { + def unapply( + composableActions: TweetInterstitial + ): Option[ConversationSectionAbusiveQuality.type] = { + composableActions.abusiveQuality + } + } + + object ComposableActionsWithSoftIntervention { + def unapply(composableActions: TweetInterstitial): Option[SoftIntervention] = { + composableActions.softIntervention match { + case Some(si: SoftIntervention) => Some(si) + case _ => None + } + } + } + + object ComposableActionsWithInterstitialLimitedEngagements { + def unapply(composableActions: TweetInterstitial): Option[InterstitialLimitedEngagements] = { + composableActions.interstitial match { + case Some(ile: InterstitialLimitedEngagements) => Some(ile) + case _ => None + } + } + } + + object ComposableActionsWithInterstitial { + def unapply(composableActions: TweetInterstitial): Option[Interstitial] = { + composableActions.interstitial match { + case Some(i: Interstitial) => Some(i) + case _ => None + } + } + } + + object ComposableActionsWithAppealable { + def unapply(composableActions: TweetInterstitial): Option[Appealable] = { + composableActions.appealable + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/Condition.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Condition.scala new file mode 100644 index 000000000..7d2dcde3c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Condition.scala @@ 
-0,0 +1,2401 @@ +package com.twitter.visibility.rules + +import com.twitter.contenthealth.sensitivemediasettings.thriftscala.SensitiveMediaSettingsLevel +import com.twitter.contenthealth.toxicreplyfilter.thriftscala.FilterState +import com.twitter.conversions.DurationOps._ +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.gizmoduck.thriftscala.MuteSurface +import com.twitter.health.platform_manipulation.stcm_tweet_holdback.StcmTweetHoldback +import com.twitter.search.common.constants.thriftscala.ThriftQuerySource +import com.twitter.snowflake.id.SnowflakeId +import com.twitter.takedown.util.TakedownReasons +import com.twitter.takedown.util.{TakedownReasons => TakedownReasonsUtil} +import com.twitter.timelines.configapi.EnumParam +import com.twitter.timelines.configapi.Param +import com.twitter.timelines.configapi.Params +import com.twitter.tseng.withholding.thriftscala.TakedownReason +import com.twitter.util.Duration +import com.twitter.util.Time +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.features.AuthorIsSuspended +import com.twitter.visibility.features.CardIsPoll +import com.twitter.visibility.features.CardUriHost +import com.twitter.visibility.features.SearchQuerySource +import com.twitter.visibility.features._ +import com.twitter.visibility.features.{AuthorBlocksOuterAuthor => AuthorBlocksOuterAuthorFeature} +import com.twitter.visibility.features.{AuthorBlocksViewer => AuthorBlocksViewerFeature} +import com.twitter.visibility.features.{ + CommunityTweetAuthorIsRemoved => CommunityTweetAuthorIsRemovedFeature +} +import com.twitter.visibility.features.{ + CommunityTweetCommunityNotFound => CommunityTweetCommunityNotFoundFeature +} +import com.twitter.visibility.features.{ + CommunityTweetCommunityDeleted => CommunityTweetCommunityDeletedFeature +} +import com.twitter.visibility.features.{ + CommunityTweetCommunitySuspended => CommunityTweetCommunitySuspendedFeature +} +import com.twitter.visibility.features.{ + CommunityTweetCommunityVisible => CommunityTweetCommunityVisibleFeature +} +import com.twitter.visibility.features.{CommunityTweetIsHidden => CommunityTweetIsHiddenFeature} +import com.twitter.visibility.features.{ + NotificationIsOnCommunityTweet => NotificationIsOnCommunityTweetFeature +} +import com.twitter.visibility.features.{OuterAuthorFollowsAuthor => OuterAuthorFollowsAuthorFeature} +import com.twitter.visibility.features.{OuterAuthorIsInnerAuthor => OuterAuthorIsInnerAuthorFeature} +import com.twitter.visibility.features.{TweetHasCard => TweetHasCardFeature} +import com.twitter.visibility.features.{TweetHasMedia => TweetHasMediaFeature} +import com.twitter.visibility.features.{TweetIsCommunityTweet => TweetIsCommunityTweetFeature} +import com.twitter.visibility.features.{TweetIsEditTweet => TweetIsEditTweetFeature} +import com.twitter.visibility.features.{TweetIsStaleTweet => TweetIsStaleTweetFeature} +import com.twitter.visibility.features.{ViewerBlocksAuthor => ViewerBlocksAuthorFeature} +import com.twitter.visibility.features.{ViewerIsCommunityAdmin => ViewerIsCommunityAdminFeature} +import com.twitter.visibility.features.{ViewerIsCommunityMember => ViewerIsCommunityMemberFeature} +import com.twitter.visibility.features.{ + ViewerIsCommunityModerator => ViewerIsCommunityModeratorFeature +} +import com.twitter.visibility.features.{ + ViewerIsInternalCommunitiesAdmin => ViewerIsInternalCommunitiesAdminFeature +} +import 
com.twitter.visibility.features.{ViewerMutesAuthor => ViewerMutesAuthorFeature} +import com.twitter.visibility.features.{ + ViewerMutesRetweetsFromAuthor => ViewerMutesRetweetsFromAuthorFeature +} +import com.twitter.visibility.models.ViolationLevel +import com.twitter.visibility.models._ +import com.twitter.visibility.rules.Result.FoundCardUriRootDomain +import com.twitter.visibility.rules.Result.FoundMediaLabel +import com.twitter.visibility.rules.Result.FoundSpaceLabel +import com.twitter.visibility.rules.Result.FoundSpaceLabelWithScoreAboveThreshold +import com.twitter.visibility.rules.Result.FoundTweetLabel +import com.twitter.visibility.rules.Result.FoundTweetLabelForPerspectivalUser +import com.twitter.visibility.rules.Result.FoundTweetLabelWithLanguageIn +import com.twitter.visibility.rules.Result.FoundTweetLabelWithLanguageScoreAboveThreshold +import com.twitter.visibility.rules.Result.FoundTweetLabelWithScoreAboveThreshold +import com.twitter.visibility.rules.Result.FoundTweetViolationOfLevel +import com.twitter.visibility.rules.Result.FoundTweetViolationOfSomeLevel +import com.twitter.visibility.rules.Result.FoundUserLabel +import com.twitter.visibility.rules.Result.FoundUserRole +import com.twitter.visibility.rules.Result.HasQuerySource +import com.twitter.visibility.rules.Result.HasTweetTimestampAfterCutoff +import com.twitter.visibility.rules.Result.HasTweetTimestampAfterOffset +import com.twitter.visibility.rules.Result.HasTweetTimestampBeforeCutoff +import com.twitter.visibility.rules.Result.ParamWasTrue +import com.twitter.visibility.rules.Result.Result +import com.twitter.visibility.rules.Result.Satisfied +import com.twitter.visibility.rules.Result.Unsatisfied +import com.twitter.visibility.util.NamingUtils +import com.twitter.visibility.{features => feats} + +sealed trait PreFilterResult +case object Filtered extends PreFilterResult +case object NeedsFullEvaluation extends PreFilterResult +case object NotFiltered extends PreFilterResult + +sealed trait Condition { + lazy val name: String = NamingUtils.getFriendlyName(this) + def features: Set[Feature[_]] + def optionalFeatures: Set[Feature[_]] + + def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + if (features.forall(featureMap.contains)) { + if (apply(evaluationContext, featureMap).asBoolean) { + NotFiltered + } else { + Filtered + } + } else { + NeedsFullEvaluation + } + } + + def apply(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): Result +} + +trait PreFilterOnOptionalFeatures extends Condition { + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + if ((features ++ optionalFeatures).forall(featureMap.contains)) { + if (apply(evaluationContext, featureMap).asBoolean) { + NotFiltered + } else { + Filtered + } + } else { + NeedsFullEvaluation + } +} + +trait HasSafetyLabelType { + val labelTypes: Set[SafetyLabelType] + def hasLabelType(labelType: SafetyLabelType): Boolean = labelTypes.contains(labelType) +} + +sealed trait HasNestedConditions extends HasSafetyLabelType { + val conditions: Seq[Condition] + override lazy val labelTypes: Set[SafetyLabelType] = conditions + .collect { + case lt: HasSafetyLabelType => lt.labelTypes + }.flatten.toSet +} + +object Result { + sealed trait ConditionReason + case object FoundInnerQuotedTweet extends ConditionReason + case object FoundTweetViolationOfSomeLevel extends ConditionReason + case class 
FoundTweetViolationOfLevel(level: ViolationLevel) extends ConditionReason + case class FoundTweetLabel(label: TweetSafetyLabelType) extends ConditionReason + case class FoundSpaceLabel(label: SpaceSafetyLabelType) extends ConditionReason + case class FoundMediaLabel(label: MediaSafetyLabelType) extends ConditionReason + case class FoundTweetLabelForPerspectivalUser(label: TweetSafetyLabelType) extends ConditionReason + case class FoundTweetLabelWithLanguageScoreAboveThreshold( + label: TweetSafetyLabelType, + languagesToScoreThresholds: Map[String, Double]) + extends ConditionReason + case class FoundTweetLabelWithScoreAboveThreshold(label: TweetSafetyLabelType, threshold: Double) + extends ConditionReason + case class FoundTweetLabelWithLanguageIn( + safetyLabelType: TweetSafetyLabelType, + languages: Set[String]) + extends ConditionReason + case class FoundTweetSafetyLabelWithPredicate(safetyLabelType: TweetSafetyLabelType, name: String) + extends ConditionReason + case class FoundUserLabel(label: UserLabelValue) extends ConditionReason + case class FoundMutedKeyword(keyword: String) extends ConditionReason + case object HasTweetTimestampAfterCutoff extends ConditionReason + case object HasTweetTimestampAfterOffset extends ConditionReason + case object HasTweetTimestampBeforeCutoff extends ConditionReason + case class IsTweetReplyToParentTweetBeforeDuration(duration: Duration) extends ConditionReason + case class IsTweetReplyToRootTweetBeforeDuration(duration: Duration) extends ConditionReason + case class HasQuerySource(querySource: ThriftQuerySource) extends ConditionReason + case class FoundUserRole(role: String) extends ConditionReason + case class ViewerInHrcj(jurisdiction: String) extends ConditionReason + case class ViewerOrRequestInCountry(country: String) extends ConditionReason + case class ViewerAgeInYears(ageInYears: Int) extends ConditionReason + case object NoViewerAge extends ConditionReason + case class ParamWasTrue(param: Param[Boolean]) extends ConditionReason + case class FoundCardUriRootDomain(domain: String) extends ConditionReason + case object Unknown extends ConditionReason + + sealed trait Result { + def asBoolean: Boolean + } + + val SatisfiedResult: Result = Satisfied() + + case class Satisfied(reason: ConditionReason = Unknown) extends Result { + override val asBoolean: Boolean = true + } + + case class Unsatisfied(condition: Condition) extends Result { + override val asBoolean: Boolean = false + } + + def fromMutedKeyword(mutedKeyword: MutedKeyword, unsatisfied: Unsatisfied): Result = { + mutedKeyword match { + case MutedKeyword(Some(keyword)) => Satisfied(FoundMutedKeyword(keyword)) + case _ => unsatisfied + } + } + + case class FoundSpaceLabelWithScoreAboveThreshold(label: SpaceSafetyLabelType, threshold: Double) + extends ConditionReason +} + +object Condition { + + abstract class BooleanFeatureCondition(feature: Feature[Boolean]) extends Condition { + override val features: Set[Feature[_]] = Set(feature) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (featureMap(feature).asInstanceOf[Boolean]) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + + case class ParamIsTrue(param: Param[Boolean]) extends Condition with HasParams { + override lazy val name: String = s"ParamIsTrue(${NamingUtils.getFriendlyName(param)})" + override val features: Set[Feature[_]] = 
Set.empty + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + private val SatisfiedResult = Satisfied(ParamWasTrue(param)) + + override val params: Set[Param[_]] = Set(param) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (evaluationContext.params(param)) { + SatisfiedResult + } else { + UnsatisfiedResult + } + } + + case object Never extends Condition { + override lazy val name: String = s"""Never""" + override val features: Set[Feature[_]] = Set.empty + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + UnsatisfiedResult + } + + class BooleanCondition(value: Boolean) extends Condition { + override lazy val name: String = s"""${if (value) "True" else "False"}""" + override val features: Set[Feature[_]] = Set.empty + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + value match { + case true => Result.SatisfiedResult + case false => UnsatisfiedResult + } + } + + case object True extends BooleanCondition(true) + case object False extends BooleanCondition(false) + + abstract class ContentTakendownInViewerCountry(takedownFeature: Feature[Seq[TakedownReason]]) + extends Condition { + override val features: Set[Feature[_]] = Set(takedownFeature) + override val optionalFeatures: Set[Feature[_]] = Set(RequestCountryCode) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val requestCountryCode = featureMap.get(RequestCountryCode).asInstanceOf[Option[String]] + val takedownReasons = featureMap(takedownFeature).asInstanceOf[Seq[TakedownReason]] + if (TakedownReasonsUtil.isTakenDownIn(requestCountryCode, takedownReasons)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object TweetTakendownInViewerCountry + extends ContentTakendownInViewerCountry(TweetTakedownReasons) + + case object AuthorTakendownInViewerCountry + extends ContentTakendownInViewerCountry(AuthorTakedownReasons) + + case object SuspendedAuthor extends BooleanFeatureCondition(AuthorIsSuspended) + + case object SuspendedViewer extends BooleanFeatureCondition(ViewerIsSuspended) + + case object DeactivatedViewer extends BooleanFeatureCondition(ViewerIsDeactivated) + + case object UnavailableAuthor extends BooleanFeatureCondition(AuthorIsUnavailable) + + case object IsVerifiedCrawlerViewer extends BooleanFeatureCondition(RequestIsVerifiedCrawler) + + case object LoggedOutViewer extends Condition { + override val features: Set[Feature[_]] = Set.empty + override val optionalFeatures: Set[Feature[_]] = Set(ViewerId) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (featureMap.contains(ViewerId)) UnsatisfiedResult else Result.SatisfiedResult + } + + case object IsSelfQuote extends Condition { + override val features: Set[Feature[_]] = Set(AuthorId, 
OuterAuthorId) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val authorIds = featureMap(AuthorId).asInstanceOf[Set[Long]] + val outerAuthorId = featureMap(OuterAuthorId).asInstanceOf[Long] + if (authorIds.contains(outerAuthorId)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object ViewerIsAuthor extends Condition { + override val features: Set[Feature[_]] = Set(AuthorId) + override val optionalFeatures: Set[Feature[_]] = Set(ViewerId) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (featureMap.contains(ViewerId)) { + val authorIds = featureMap(AuthorId).asInstanceOf[Set[Long]] + val viewerId = featureMap(ViewerId).asInstanceOf[Long] + if (authorIds.contains(viewerId)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } else { + UnsatisfiedResult + } + } + + case object NonAuthorViewer extends Condition { + override val features: Set[Feature[_]] = Set(AuthorId) + override val optionalFeatures: Set[Feature[_]] = Set(ViewerId) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (featureMap.contains(ViewerId)) { + val authorIds = featureMap(AuthorId).asInstanceOf[Set[Long]] + val viewerId = featureMap(ViewerId).asInstanceOf[Long] + if (authorIds.contains(viewerId)) { + UnsatisfiedResult + } else { + Result.SatisfiedResult + } + } else { + Result.SatisfiedResult + } + } + + case object ViewerFollowsAuthorOfFosnrViolatingTweet + extends BooleanFeatureCondition(ViewerFollowsAuthorOfViolatingTweet) + + case object ViewerDoesNotFollowAuthorOfFosnrViolatingTweet + extends BooleanFeatureCondition(ViewerDoesNotFollowAuthorOfViolatingTweet) + + case object ViewerDoesFollowAuthor extends BooleanFeatureCondition(ViewerFollowsAuthor) + + case object AuthorDoesFollowViewer extends BooleanFeatureCondition(AuthorFollowsViewer) + + case object AuthorBlocksViewer extends BooleanFeatureCondition(AuthorBlocksViewerFeature) + + case object ViewerBlocksAuthor extends BooleanFeatureCondition(ViewerBlocksAuthorFeature) + + case object ViewerIsUnmentioned extends BooleanFeatureCondition(NotificationIsOnUnmentionedViewer) + + case object AuthorBlocksOuterAuthor + extends BooleanFeatureCondition(AuthorBlocksOuterAuthorFeature) + + case object OuterAuthorFollowsAuthor + extends BooleanFeatureCondition(OuterAuthorFollowsAuthorFeature) + + case object OuterAuthorIsInnerAuthor + extends BooleanFeatureCondition(OuterAuthorIsInnerAuthorFeature) + + case object ViewerMutesAuthor extends BooleanFeatureCondition(ViewerMutesAuthorFeature) + + case object ViewerReportsAuthor extends BooleanFeatureCondition(ViewerReportsAuthorAsSpam) + case object ViewerReportsTweet extends BooleanFeatureCondition(ViewerReportedTweet) + + case object IsQuotedInnerTweet extends BooleanFeatureCondition(TweetIsInnerQuotedTweet) + + case object IsSourceTweet extends BooleanFeatureCondition(TweetIsSourceTweet) + + case object ViewerMutesRetweetsFromAuthor + extends BooleanFeatureCondition(ViewerMutesRetweetsFromAuthorFeature) + + case object ConversationRootAuthorDoesFollowViewer + extends BooleanFeatureCondition(ConversationRootAuthorFollowsViewer) + + case object 
ViewerDoesFollowConversationRootAuthor + extends BooleanFeatureCondition(ViewerFollowsConversationRootAuthor) + + case object TweetIsCommunityTweet extends BooleanFeatureCondition(TweetIsCommunityTweetFeature) + + case object NotificationIsOnCommunityTweet + extends BooleanFeatureCondition(NotificationIsOnCommunityTweetFeature) + + sealed trait CommunityTweetCommunityUnavailable extends Condition + + case object CommunityTweetCommunityNotFound + extends BooleanFeatureCondition(CommunityTweetCommunityNotFoundFeature) + with CommunityTweetCommunityUnavailable + + case object CommunityTweetCommunityDeleted + extends BooleanFeatureCondition(CommunityTweetCommunityDeletedFeature) + with CommunityTweetCommunityUnavailable + + case object CommunityTweetCommunitySuspended + extends BooleanFeatureCondition(CommunityTweetCommunitySuspendedFeature) + with CommunityTweetCommunityUnavailable + + case object CommunityTweetCommunityVisible + extends BooleanFeatureCondition(CommunityTweetCommunityVisibleFeature) + + case object ViewerIsInternalCommunitiesAdmin + extends BooleanFeatureCondition(ViewerIsInternalCommunitiesAdminFeature) + + case object ViewerIsCommunityAdmin extends BooleanFeatureCondition(ViewerIsCommunityAdminFeature) + + case object ViewerIsCommunityModerator + extends BooleanFeatureCondition(ViewerIsCommunityModeratorFeature) + + case object ViewerIsCommunityMember + extends BooleanFeatureCondition(ViewerIsCommunityMemberFeature) + + sealed trait CommunityTweetIsModerated extends Condition + + case object CommunityTweetIsHidden + extends BooleanFeatureCondition(CommunityTweetIsHiddenFeature) + with CommunityTweetIsModerated + + case object CommunityTweetAuthorIsRemoved + extends BooleanFeatureCondition(CommunityTweetAuthorIsRemovedFeature) + with CommunityTweetIsModerated + + case object DoesHaveInnerCircleOfFriendsRelationship + extends BooleanFeatureCondition(HasInnerCircleOfFriendsRelationship) + + case object TweetIsCommunityConversation + extends BooleanFeatureCondition(TweetHasCommunityConversationControl) + + case object TweetIsByInvitationConversation + extends BooleanFeatureCondition(TweetHasByInvitationConversationControl) + + case object TweetIsFollowersConversation + extends BooleanFeatureCondition(TweetHasFollowersConversationControl) + + case object ViewerIsTweetConversationRootAuthor + extends BooleanFeatureCondition(TweetConversationViewerIsRootAuthor) + + private case object ViewerIsInvitedToTweetConversationByMention + extends BooleanFeatureCondition(TweetConversationViewerIsInvited) + + private case object ViewerIsInvitedToTweetConversationByReplyMention + extends BooleanFeatureCondition(TweetConversationViewerIsInvitedViaReplyMention) + + object ViewerIsInvitedToTweetConversation + extends Or( + ViewerIsInvitedToTweetConversationByMention, + ViewerIsInvitedToTweetConversationByReplyMention) + + object TweetIsExclusiveContent extends BooleanFeatureCondition(TweetIsExclusiveTweet) + object ViewerIsExclusiveTweetAuthor + extends BooleanFeatureCondition(ViewerIsExclusiveTweetRootAuthor) + object ViewerSuperFollowsExclusiveTweetAuthor + extends BooleanFeatureCondition(ViewerSuperFollowsExclusiveTweetRootAuthor) + + object TweetIsTrustedFriendsContent extends BooleanFeatureCondition(TweetIsTrustedFriendTweet) + object ViewerIsTrustedFriendsTweetAuthor + extends BooleanFeatureCondition(ViewerIsTrustedFriendTweetAuthor) + object ViewerIsTrustedFriend extends BooleanFeatureCondition(ViewerIsTrustedFriendOfTweetAuthor) + + object TweetIsCollabInvitationContent + extends 
BooleanFeatureCondition(TweetIsCollabInvitationTweet) + + case class TweetHasLabelForPerspectivalUser(safetyLabel: TweetSafetyLabelType) + extends Condition + with HasSafetyLabelType { + override lazy val name: String = s"TweetHasLabelForPerspectivalUser(${safetyLabel.name})" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set(ViewerId) + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabel) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied( + FoundTweetLabelForPerspectivalUser(safetyLabel) + ) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + if (!featureMap.contains(ViewerId)) { + UnsatisfiedResult + } else { + val viewerId = featureMap(ViewerId).asInstanceOf[Long] + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + labels + .collectFirst { + case label + if label.labelType == safetyLabel && label.applicableUsers.contains(viewerId) + && ExperimentBase.shouldFilterForSource(evaluationContext.params, label.source) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + } + + case class TweetHasLabel( + safetyLabel: TweetSafetyLabelType, + labelSourceExperimentPredicate: Option[(Params, Option[LabelSource]) => Boolean] = None) + extends Condition + with HasSafetyLabelType { + override lazy val name: String = s"TweetHasLabel(${safetyLabel.name})" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabel) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundTweetLabel(safetyLabel)) + + private val labelSourcePredicate: (Params, Option[LabelSource]) => Boolean = + labelSourceExperimentPredicate match { + case Some(predicate) => predicate + case _ => ExperimentBase.shouldFilterForSource + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + labels + .collectFirst { + case label + if label.labelType == safetyLabel + && labelSourcePredicate(evaluationContext.params, label.source) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class SpaceHasLabel( + safetyLabelType: SpaceSafetyLabelType) + extends Condition + with HasSafetyLabelType { + override lazy val name: String = s"SpaceHasLabel(${safetyLabelType.name})" + override val features: Set[Feature[_]] = Set(SpaceSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundSpaceLabel(safetyLabelType)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(SpaceSafetyLabels).asInstanceOf[Seq[SpaceSafetyLabel]] + labels + .collectFirst { + case label if label.safetyLabelType == safetyLabelType => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class MediaHasLabel( + safetyLabelType: MediaSafetyLabelType) + extends Condition + with HasSafetyLabelType { + override lazy val name: String = 
s"MediaHasLabel(${safetyLabelType.name})" + override val features: Set[Feature[_]] = Set(MediaSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundMediaLabel(safetyLabelType)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(MediaSafetyLabels).asInstanceOf[Seq[MediaSafetyLabel]] + labels + .collectFirst { + case label if label.safetyLabelType == safetyLabelType => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetHasLabelWithLanguageScoreAboveThreshold( + safetyLabel: TweetSafetyLabelType, + languagesToScoreThresholds: Map[String, Double]) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = + s"TweetHasLabelWithLanguageScoreAboveThreshold(${safetyLabel.name}, ${languagesToScoreThresholds.toString})" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabel) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = + Satisfied( + FoundTweetLabelWithLanguageScoreAboveThreshold(safetyLabel, languagesToScoreThresholds)) + + private[this] def isAboveThreshold(label: TweetSafetyLabel) = { + val isAboveThresholdOpt = for { + modelMetadata <- label.modelMetadata + calibratedLanguage <- modelMetadata.calibratedLanguage + threshold <- languagesToScoreThresholds.get(calibratedLanguage) + score <- label.score + } yield score >= threshold + + isAboveThresholdOpt.getOrElse(false) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + labels + .collectFirst { + case label + if label.labelType == safetyLabel + && isAboveThreshold(label) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetHasLabelWithScoreAboveThreshold( + safetyLabel: TweetSafetyLabelType, + threshold: Double) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = + s"TweetHasLabelWithScoreAboveThreshold(${safetyLabel.name}, $threshold)" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabel) + + private val UnsatisfiedResult = Unsatisfied(this) + private val SatisfiedResult = + Satisfied(FoundTweetLabelWithScoreAboveThreshold(safetyLabel, threshold)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + labels + .collectFirst { + case label + if label.labelType == safetyLabel + && label.score.exists(_ >= threshold) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetHasLabelWithScoreAboveThresholdWithParam( + safetyLabel: TweetSafetyLabelType, + thresholdParam: Param[Double]) + extends Condition + with HasSafetyLabelType + with HasParams { + override lazy val name: String = + s"TweetHasLabelWithScoreAboveThreshold(${safetyLabel.name}, ${NamingUtils.getFriendlyName(thresholdParam)})" + 
override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabel) + private val UnsatisfiedResult = Unsatisfied(this) + override val params: Set[Param[_]] = Set(thresholdParam) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + val threshold = evaluationContext.params(thresholdParam) + val SatisfiedResult = + Satisfied(FoundTweetLabelWithScoreAboveThreshold(safetyLabel, threshold)) + labels + .collectFirst { + case label + if label.labelType == safetyLabel + && label.score.exists(_ >= threshold) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetHasLabelWithLanguageIn( + safetyLabelType: TweetSafetyLabelType, + languages: Set[String]) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = + s"TweetHasLabelWithLanguageIn($safetyLabelType, $languages)" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = + Satisfied(FoundTweetLabelWithLanguageIn(safetyLabelType, languages)) + + private[this] def hasLanguageMatch(label: TweetSafetyLabel): Boolean = { + val isMatchingLanguageOpt = for { + metadata <- label.modelMetadata + language <- metadata.calibratedLanguage + } yield languages.contains(language) + isMatchingLanguageOpt.getOrElse(false) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .collectFirst { + case label if label.labelType == safetyLabelType && hasLanguageMatch(label) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetHasLabelWithLanguagesWithParam( + safetyLabelType: TweetSafetyLabelType, + languageParam: Param[Seq[String]]) + extends Condition + with HasSafetyLabelType + with HasParams { + override lazy val name: String = + s"TweetHasLabelWithLanguageIn($safetyLabelType, ${NamingUtils.getFriendlyName(languageParam)})" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + override val params: Set[Param[_]] = Set(languageParam) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + + private[this] def hasLanguageMatch(label: TweetSafetyLabel, languages: Set[String]): Boolean = { + val isMatchingLanguageOpt = for { + metadata <- label.modelMetadata + language <- metadata.calibratedLanguage + } yield languages.contains(language) + isMatchingLanguageOpt.getOrElse(false) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val languages = evaluationContext.params(languageParam).toSet + val SatisfiedResult: Satisfied = + Satisfied(FoundTweetLabelWithLanguageIn(safetyLabelType, languages)) + featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .collectFirst { + case label if label.labelType == safetyLabelType && hasLanguageMatch(label, languages) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } 
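+ // Note: hasLanguageMatch only succeeds when the label carries modelMetadata with a + // calibratedLanguage, so labels without model metadata never satisfy this condition.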
+ } + + type TweetSafetyLabelPredicateFn = (TweetSafetyLabel) => Boolean + abstract class NamedTweetSafetyLabelPredicate( + private[rules] val fn: TweetSafetyLabelPredicateFn, + private[rules] val name: String) + + abstract class TweetHasSafetyLabelWithPredicate( + private[rules] val safetyLabelType: TweetSafetyLabelType, + private[rules] val predicate: NamedTweetSafetyLabelPredicate) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = + s"TweetHasSafetyLabelWithPredicate(${predicate.name}($safetyLabelType))" + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = + Satisfied(Result.FoundTweetSafetyLabelWithPredicate(safetyLabelType, predicate.name)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .collectFirst { + case label if label.labelType == safetyLabelType && predicate.fn(label) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + object TweetHasSafetyLabelWithPredicate { + def unapply( + condition: TweetHasSafetyLabelWithPredicate + ): Option[(TweetSafetyLabelType, NamedTweetSafetyLabelPredicate)] = + Some((condition.safetyLabelType, condition.predicate)) + } + + case class WithScoreEqInt(score: Int) + extends NamedTweetSafetyLabelPredicate( + fn = tweetSafetyLabel => tweetSafetyLabel.score.exists(s => s.intValue() == score), + name = "WithScoreEqInt" + ) + case class TweetHasSafetyLabelWithScoreEqInt( + override val safetyLabelType: TweetSafetyLabelType, + score: Int) + extends TweetHasSafetyLabelWithPredicate( + safetyLabelType, + predicate = WithScoreEqInt(score) + ) + + case class TweetReplyToParentTweetBeforeDuration(duration: Duration) extends Condition { + override val features: Set[Feature[_]] = Set(TweetParentId, TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied( + Result.IsTweetReplyToParentTweetBeforeDuration(duration)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap + .get(TweetParentId).collect { + case tweetParentId: Long => + featureMap + .get(TweetTimestamp).collect { + case tweetTimestamp: Time + if tweetTimestamp.diff(SnowflakeId.timeFromId(tweetParentId)) < duration => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetReplyToRootTweetBeforeDuration(duration: Duration) extends Condition { + override val features: Set[Feature[_]] = Set(TweetConversationId, TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied( + Result.IsTweetReplyToRootTweetBeforeDuration(duration)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap + .get(TweetConversationId).collect { + case tweetConversationId: Long => + featureMap + .get(TweetTimestamp).collect { + case tweetTimestamp: Time + if tweetTimestamp.diff( + SnowflakeId.timeFromId(tweetConversationId)) < 
duration => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + }.getOrElse(UnsatisfiedResult) + } + } + + case class TweetComposedBefore(cutoffTimestamp: Time) extends Condition { + assert(cutoffTimestamp.inMilliseconds > SnowflakeId.FirstSnowflakeIdUnixTime) + + override val features: Set[Feature[_]] = Set(TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(HasTweetTimestampBeforeCutoff) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap(TweetTimestamp) match { + case timestamp: Time if timestamp > cutoffTimestamp => UnsatisfiedResult + case _ => SatisfiedResult + } + } + } + + case class TweetComposedAfter(cutoffTimestamp: Time) extends Condition { + assert(cutoffTimestamp.inMilliseconds > SnowflakeId.FirstSnowflakeIdUnixTime) + + override val features: Set[Feature[_]] = Set(TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(HasTweetTimestampAfterCutoff) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap(TweetTimestamp) match { + case timestamp: Time if timestamp > cutoffTimestamp => SatisfiedResult + case _ => UnsatisfiedResult + } + } + } + + case class TweetComposedAfterOffset(offset: Duration) extends Condition { + override val features: Set[Feature[_]] = Set(TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(HasTweetTimestampAfterOffset) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap(TweetTimestamp) match { + case timestamp: Time if timestamp > Time.now.minus(offset) => SatisfiedResult + case _ => UnsatisfiedResult + } + } + } + + case class TweetComposedAfterWithParam(cutoffTimeParam: Param[Time]) + extends Condition + with HasParams { + override val features: Set[Feature[_]] = Set(TweetTimestamp) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val params: Set[Param[_]] = Set(cutoffTimeParam) + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(HasTweetTimestampAfterCutoff) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + val cutoffTimestamp = evaluationContext.params(cutoffTimeParam) + if (cutoffTimestamp.inMilliseconds < SnowflakeId.FirstSnowflakeIdUnixTime) { + Filtered + } else { + super.preFilter(evaluationContext, featureMap) + } + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val cutoffTimestamp = evaluationContext.params(cutoffTimeParam) + featureMap(TweetTimestamp) match { + case _: Time if cutoffTimestamp.inMilliseconds < SnowflakeId.FirstSnowflakeIdUnixTime => + UnsatisfiedResult + case timestamp: Time if timestamp > cutoffTimestamp => SatisfiedResult + case _ => UnsatisfiedResult + } + } + } + + case class AuthorHasLabel(labelValue: UserLabelValue, shortCircuitable: Boolean = true) + extends Condition + with HasSafetyLabelType { + override lazy val name: String = 
s"AuthorHasLabel(${labelValue.name})" + override val features: Set[Feature[_]] = Set(AuthorUserLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(labelValue) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundUserLabel(labelValue)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(AuthorUserLabels).asInstanceOf[Seq[Label]].map(UserLabel.fromThrift) + labels + .collectFirst { + case label + if label.labelValue == labelValue + && ExperimentBase.shouldFilterForSource(evaluationContext.params, label.source) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + abstract class ViewerHasRole(role: String) extends Condition { + override lazy val name: String = s"ViewerHasRole(${role})" + override val features: Set[Feature[_]] = Set(ViewerRoles) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundUserRole(role)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val roles = featureMap(ViewerRoles).asInstanceOf[Seq[String]] + if (roles.contains(role)) { + SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object ViewerIsEmployee extends ViewerHasRole(ViewerRoles.EmployeeRole) + + case class ViewerHasLabel(labelValue: UserLabelValue) extends Condition with HasSafetyLabelType { + override lazy val name: String = s"ViewerHasLabel(${labelValue.name})" + override val features: Set[Feature[_]] = Set(ViewerUserLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + override val labelTypes: Set[SafetyLabelType] = Set(labelValue) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + private val SatisfiedResult: Satisfied = Satisfied(FoundUserLabel(labelValue)) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(ViewerUserLabels).asInstanceOf[Seq[Label]].map(UserLabel.fromThrift) + labels + .collectFirst { + case label + if label.labelValue == labelValue + && ExperimentBase.shouldFilterForSource(evaluationContext.params, label.source) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case object DeactivatedAuthor extends BooleanFeatureCondition(AuthorIsDeactivated) + case object ErasedAuthor extends BooleanFeatureCondition(AuthorIsErased) + case object OffboardedAuthor extends BooleanFeatureCondition(AuthorIsOffboarded) + case object ProtectedAuthor extends BooleanFeatureCondition(AuthorIsProtected) + case object VerifiedAuthor extends BooleanFeatureCondition(AuthorIsVerified) + case object NsfwUserAuthor extends BooleanFeatureCondition(AuthorIsNsfwUser) + case object NsfwAdminAuthor extends BooleanFeatureCondition(AuthorIsNsfwAdmin) + case object TweetHasNsfwUserAuthor extends BooleanFeatureCondition(TweetHasNsfwUser) + case object TweetHasNsfwAdminAuthor extends BooleanFeatureCondition(TweetHasNsfwAdmin) + case object TweetHasMedia extends BooleanFeatureCondition(TweetHasMediaFeature) + case object TweetHasDmcaMedia extends BooleanFeatureCondition(HasDmcaMediaFeature) + case object TweetHasCard extends BooleanFeatureCondition(TweetHasCardFeature) + case object IsPollCard extends BooleanFeatureCondition(CardIsPoll) + + case 
object ProtectedViewer extends BooleanFeatureCondition(ViewerIsProtected) + case object SoftViewer extends BooleanFeatureCondition(ViewerIsSoftUser) + + case object ViewerHasUqfEnabled + extends BooleanFeatureCondition(ViewerHasUniversalQualityFilterEnabled) + + abstract class ViewerHasMatchingKeywordFor(muteSurface: MuteSurface) extends Condition { + override def features: Set[Feature[_]] = Set(feature) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + private val feature: Feature[MutedKeyword] = muteSurface match { + case MuteSurface.HomeTimeline => ViewerMutesKeywordInTweetForHomeTimeline + case MuteSurface.Notifications => ViewerMutesKeywordInTweetForNotifications + case MuteSurface.TweetReplies => ViewerMutesKeywordInTweetForTweetReplies + + case _ => throw new NoSuchElementException(muteSurface.toString) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val mutedKeyword = featureMap(feature) + .asInstanceOf[MutedKeyword] + Result.fromMutedKeyword(mutedKeyword, UnsatisfiedResult) + } + } + + case object ViewerHasMatchingKeywordForHomeTimeline + extends ViewerHasMatchingKeywordFor(MuteSurface.HomeTimeline) + + case object ViewerHasMatchingKeywordForNotifications + extends ViewerHasMatchingKeywordFor(MuteSurface.Notifications) + + case object ViewerHasMatchingKeywordForTweetReplies + extends ViewerHasMatchingKeywordFor(MuteSurface.TweetReplies) + + case object ViewerHasMatchingKeywordForAllSurfaces extends Condition { + override def features: Set[Feature[_]] = Set(ViewerMutesKeywordInTweetForAllSurfaces) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val mutedKeyword = featureMap(ViewerMutesKeywordInTweetForAllSurfaces) + .asInstanceOf[MutedKeyword] + Result.fromMutedKeyword(mutedKeyword, UnsatisfiedResult) + } + } + + abstract class ViewerHasMatchingKeywordInSpaceTitleFor(muteSurface: MuteSurface) + extends Condition { + override def features: Set[Feature[_]] = Set(feature) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + private val feature: Feature[MutedKeyword] = muteSurface match { + case MuteSurface.Notifications => ViewerMutesKeywordInSpaceTitleForNotifications + case _ => throw new NoSuchElementException(muteSurface.toString) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val mutedKeyword = featureMap(feature) + .asInstanceOf[MutedKeyword] + Result.fromMutedKeyword(mutedKeyword, UnsatisfiedResult) + } + } + + case object ViewerHasMatchingKeywordInSpaceTitleForNotifications + extends ViewerHasMatchingKeywordInSpaceTitleFor(MuteSurface.Notifications) + + case object ViewerFiltersNoConfirmedEmail + extends BooleanFeatureCondition( + com.twitter.visibility.features.ViewerFiltersNoConfirmedEmail + ) + + case object ViewerFiltersNoConfirmedPhone + extends BooleanFeatureCondition( + com.twitter.visibility.features.ViewerFiltersNoConfirmedPhone + ) + + case object ViewerFiltersDefaultProfileImage + extends BooleanFeatureCondition( + com.twitter.visibility.features.ViewerFiltersDefaultProfileImage + ) + + case object ViewerFiltersNewUsers + extends BooleanFeatureCondition( + com.twitter.visibility.features.ViewerFiltersNewUsers + ) + + 
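// The ViewerFilters* and AuthorHas* conditions in this block back the advanced filtering + // rules in AdvancedFilteringRules.scala: e.g. NoConfirmedEmailRule combines + // ViewerFiltersNoConfirmedEmail with Not(AuthorHasConfirmedEmail), alongside the usual + // NonAuthorViewer and Not(ViewerDoesFollowAuthor) guards. + 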
case object ViewerFiltersNotFollowedBy + extends BooleanFeatureCondition( + com.twitter.visibility.features.ViewerFiltersNotFollowedBy + ) + + case object AuthorHasConfirmedEmail + extends BooleanFeatureCondition( + com.twitter.visibility.features.AuthorHasConfirmedEmail + ) + + case object AuthorHasVerifiedPhone + extends BooleanFeatureCondition( + com.twitter.visibility.features.AuthorHasVerifiedPhone + ) + + case object AuthorHasDefaultProfileImage + extends BooleanFeatureCondition( + com.twitter.visibility.features.AuthorHasDefaultProfileImage + ) + + case object AuthorIsNewAccount extends Condition { + override val features: Set[Feature[_]] = Set(AuthorAccountAge) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val age = featureMap(AuthorAccountAge).asInstanceOf[Duration] + + if (age < 72.hours) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + abstract class ViewerInJurisdiction extends Condition { + override def features: Set[Feature[_]] = Set.empty + override val optionalFeatures: Set[Feature[_]] = Set(RequestCountryCode, ViewerCountryCode) + + protected val unsatisfiedResult = Unsatisfied(this) + + protected case class CountryFeatures( + requestCountryCode: Option[String], + viewerCountryCode: Option[String]) + + def getCountryFeatures(featureMap: Map[Feature[_], _]): CountryFeatures = { + val requestCountryCodeOpt = featureMap + .get(RequestCountryCode) + .map(_.asInstanceOf[String]) + val viewerCountryCodeOpt = featureMap + .get(ViewerCountryCode) + .map(_.asInstanceOf[String]) + + CountryFeatures(requestCountryCodeOpt, viewerCountryCodeOpt) + } + } + + case class ViewerInHrcj(jurisdictions: Set[String]) extends ViewerInJurisdiction { + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + featureMap + .get(RequestCountryCode) + .map(_.asInstanceOf[String]) + .collectFirst { + case rcc if jurisdictions.contains(rcc) => NeedsFullEvaluation + } + .getOrElse(Filtered) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val countryFeatures = getCountryFeatures(featureMap) + + countryFeatures match { + case CountryFeatures(Some(rcc), Some(vcc)) + if jurisdictions.contains(rcc) && vcc.equals(rcc) => + Satisfied(Result.ViewerInHrcj(rcc)) + case _ => unsatisfiedResult + } + } + } + + case class ViewerOrRequestInJurisdiction(enabledCountriesParam: Param[Seq[String]]) + extends ViewerInJurisdiction + with HasParams + with PreFilterOnOptionalFeatures { + + override val params: Set[Param[_]] = Set(enabledCountriesParam) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val countries: Seq[String] = + evaluationContext.params(enabledCountriesParam).map(c => c.toLowerCase) + val countryFeatures = getCountryFeatures(featureMap) + + val countryCodeOpt = + countryFeatures.viewerCountryCode.orElse(countryFeatures.requestCountryCode) + + countryCodeOpt match { + case Some(countryCode) if countries.contains(countryCode) => + Satisfied(Result.ViewerOrRequestInCountry(countryCode)) + case _ => unsatisfiedResult + } + } + } + + case class ViewerAgeInYearsGte(ageToCompare: Int, ignoreEmptyAge: Boolean = false) + extends Condition + with PreFilterOnOptionalFeatures { + override def features: Set[Feature[_]] 
= Set.empty + override def optionalFeatures: Set[Feature[_]] = Set(ViewerAge) + + private val unsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + featureMap + .get(ViewerAge) + .map(_.asInstanceOf[UserAge]) + .collectFirst { + case UserAge(Some(age)) if age >= ageToCompare => + Satisfied(Result.ViewerAgeInYears(age)) + case UserAge(None) if ignoreEmptyAge => + Satisfied(Result.NoViewerAge) + } + .getOrElse(unsatisfiedResult) + } + + case class UnderageViewer(ageToCompare: Int) extends Condition with PreFilterOnOptionalFeatures { + override def features: Set[Feature[_]] = Set.empty + override def optionalFeatures: Set[Feature[_]] = Set(ViewerAge) + + private val unsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + featureMap + .get(ViewerAge) + .map(_.asInstanceOf[UserAge]) + .collectFirst { + case UserAge(Some(age)) if age < ageToCompare => Satisfied(Result.ViewerAgeInYears(age)) + } + .getOrElse(unsatisfiedResult) + } + + case object ViewerMissingAge extends Condition with PreFilterOnOptionalFeatures { + override def features: Set[Feature[_]] = Set.empty + override def optionalFeatures: Set[Feature[_]] = Set(ViewerAge) + + private val unsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + featureMap + .get(ViewerAge) + .map(_.asInstanceOf[UserAge]) + .collectFirst { + case UserAge(None) => Satisfied(Result.NoViewerAge) + } + .getOrElse(unsatisfiedResult) + } + + case object ViewerOptInBlockingOnSearch extends BooleanFeatureCondition(ViewerOptInBlocking) + case object ViewerOptInFilteringOnSearch extends BooleanFeatureCondition(ViewerOptInFiltering) + case object SelfReply extends BooleanFeatureCondition(TweetIsSelfReply) + case object Nullcast extends BooleanFeatureCondition(TweetIsNullcast) + case object Moderated extends BooleanFeatureCondition(TweetIsModerated) + case object Retweet extends BooleanFeatureCondition(TweetIsRetweet) + + case object IsFirstPageSearchResult extends Condition { + override val features: Set[Feature[_]] = Set(SearchResultsPageNumber) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val searchResultsPageNumber = featureMap(SearchResultsPageNumber).asInstanceOf[Int] + if (searchResultsPageNumber == 1) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object HasSearchCandidateCountGreaterThan45 extends Condition { + override val features: Set[Feature[_]] = Set(SearchCandidateCount) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val searchCandidateCount = featureMap(SearchCandidateCount).asInstanceOf[Int] + if (searchCandidateCount > 45) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + abstract class HasSearchQuerySource(querySourceToMatch: ThriftQuerySource) extends Condition { + override lazy val name: String = s"HasSearchQuerySource(${querySourceToMatch})" + override val features: Set[Feature[_]] = Set(SearchQuerySource) + override val optionalFeatures: 
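+    // A hedged sketch of how this abstract class is meant to be specialized;
+    // compare the one-liner case objects further down (IsSearchHashtagClick,
+    // IsSearchTrendClick). The name and enum value here are illustrative only:
+    //   case object IsSearchTypedQuery extends HasSearchQuerySource(ThriftQuerySource.TypedQuery)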
Set[Feature[_]] = Set.empty
+    private val UnsatisfiedResult = Unsatisfied(this)
+    private val SatisfiedResult: Satisfied = Satisfied(HasQuerySource(querySourceToMatch))
+
+    override def apply(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): Result = {
+      val querySource = featureMap(SearchQuerySource).asInstanceOf[ThriftQuerySource]
+      if (querySourceToMatch.equals(querySource)) {
+        SatisfiedResult
+      } else {
+        UnsatisfiedResult
+      }
+    }
+  }
+
+  case object IsTrendClickSourceSearchResult extends Condition {
+    override val features: Set[Feature[_]] = Set(SearchQuerySource)
+    override val optionalFeatures: Set[Feature[_]] = Set.empty
+    private val UnsatisfiedResult = Unsatisfied(this)
+
+    private def checkQuerySource[T](
+      featureMap: Map[Feature[_], _],
+      nonTrendSourceResult: T,
+      trendSourceResult: T
+    ): T = {
+      val querySource = featureMap(SearchQuerySource).asInstanceOf[ThriftQuerySource]
+      if (querySource == ThriftQuerySource.TrendClick) {
+        trendSourceResult
+      } else {
+        nonTrendSourceResult
+      }
+    }
+
+    override def preFilter(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): PreFilterResult =
+      checkQuerySource(featureMap, Filtered, NotFiltered)
+
+    override def apply(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): Result =
+      checkQuerySource(featureMap, UnsatisfiedResult, Result.SatisfiedResult)
+  }
+  case object IsSearchHashtagClick extends HasSearchQuerySource(ThriftQuerySource.HashtagClick)
+  case object IsSearchTrendClick extends HasSearchQuerySource(ThriftQuerySource.TrendClick)
+
+  case object SearchQueryHasUser
+    extends BooleanFeatureCondition(com.twitter.visibility.features.SearchQueryHasUser)
+
+  case class Equals[T](feature: Feature[T], value: T) extends Condition {
+
+    override def features: Set[Feature[_]] = Set(feature)
+    override val optionalFeatures: Set[Feature[_]] = Set.empty
+
+    private val SatisfiedResult: Result = Satisfied()
+    private val UnsatisfiedResult: Result = Unsatisfied(this)
+
+    override def apply(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): Result = {
+      val featureValue = featureMap(feature).asInstanceOf[T]
+      if (featureValue.equals(value)) SatisfiedResult else UnsatisfiedResult
+    }
+  }
+
+  case class FeatureEquals[T](left: Feature[T], right: Feature[T]) extends Condition {
+
+    override def features: Set[Feature[_]] = Set.empty
+    override val optionalFeatures: Set[Feature[_]] = Set(left, right)
+
+    private val SatisfiedResult: Result = Satisfied()
+    private val UnsatisfiedResult: Result = Unsatisfied(this)
+
+    override def preFilter(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): PreFilterResult = {
+      if (featureMap.contains(left) && featureMap.contains(right)) {
+        if (apply(evaluationContext, featureMap).asBoolean) {
+          NotFiltered
+        } else {
+          Filtered
+        }
+      } else {
+        NeedsFullEvaluation
+      }
+    }
+
+    override def apply(
+      evaluationContext: EvaluationContext,
+      featureMap: Map[Feature[_], _]
+    ): Result = {
+      if (featureMap.contains(left) && featureMap.contains(right)) {
+        val leftValue = featureMap(left).asInstanceOf[T]
+        val rightValue = featureMap(right).asInstanceOf[T]
+        if (leftValue.equals(rightValue)) SatisfiedResult else UnsatisfiedResult
+      } else {
+        UnsatisfiedResult
+      }
+    }
+  }
+
+  case class And(override val conditions: Condition*)
+    extends Condition
+    with HasNestedConditions
+    with HasParams {
+    override lazy val name: String =
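+    // And combines preFilters with Filtered as the absorbing state: the fold
+    // below stays NotFiltered only while every nested condition agrees, drops
+    // to NeedsFullEvaluation on the first undecided child, and any Filtered
+    // child filters the whole conjunction.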
s"(${conditions.map(_.name).mkString(" And ")})" + override val features: Set[Feature[_]] = conditions.flatMap(_.features).toSet + override val optionalFeatures: Set[Feature[_]] = conditions.flatMap(_.optionalFeatures).toSet + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + conditions.foldLeft(NotFiltered: PreFilterResult) { + case (NotFiltered, condition) => condition.preFilter(evaluationContext, featureMap) + case (Filtered, _) => Filtered + case (NeedsFullEvaluation, condition) => { + condition.preFilter(evaluationContext, featureMap) match { + case Filtered => Filtered + case _ => NeedsFullEvaluation + } + } + } + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + conditions.foldLeft(Result.SatisfiedResult) { + case (result @ Unsatisfied(_), _) => result + case (Result.SatisfiedResult, condition) => condition.apply(evaluationContext, featureMap) + case (result @ Satisfied(_), condition) => { + condition.apply(evaluationContext, featureMap) match { + case r @ Unsatisfied(_) => r + case _ => result + } + } + } + } + } + + case class Or(override val conditions: Condition*) + extends Condition + with HasNestedConditions + with HasParams { + override lazy val name: String = s"(${conditions.map(_.name).mkString(" Or ")})" + override val features: Set[Feature[_]] = conditions.flatMap(_.features).toSet + override val optionalFeatures: Set[Feature[_]] = conditions.flatMap(_.optionalFeatures).toSet + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + private val UnsatisfiedResult = Unsatisfied(this) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + conditions.foldLeft(Filtered: PreFilterResult) { + case (Filtered, c) => c.preFilter(evaluationContext, featureMap) + case (NotFiltered, _) => NotFiltered + case (NeedsFullEvaluation, c) => { + c.preFilter(evaluationContext, featureMap) match { + case NotFiltered => NotFiltered + case _ => NeedsFullEvaluation + } + } + } + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val foundSatisfiedCondition = + conditions.find(_.apply(evaluationContext, featureMap).asBoolean) + if (foundSatisfiedCondition.isDefined) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case class Not(condition: Condition) extends Condition with HasNestedConditions with HasParams { + override lazy val name: String = s"Not(${condition.name})" + override val features: Set[Feature[_]] = condition.features + override val optionalFeatures: Set[Feature[_]] = condition.optionalFeatures + override val conditions: Seq[Condition] = Seq(condition) + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + private val UnsatisfiedResult = Unsatisfied(this) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + condition.preFilter(evaluationContext, featureMap) match { + case Filtered => NotFiltered + case NotFiltered => Filtered + case _ => NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if 
(condition(evaluationContext, featureMap).asBoolean) { + UnsatisfiedResult + } else { + Result.SatisfiedResult + } + } + + val LoggedOutOrViewerNotFollowingAuthor: And = + And(NonAuthorViewer, Or(LoggedOutViewer, Not(ViewerDoesFollowAuthor))) + + val LoggedOutOrViewerOptInFiltering: Or = + Or(LoggedOutViewer, ViewerOptInFilteringOnSearch) + + val LoggedInViewer: Not = Not(LoggedOutViewer) + + val OuterAuthorNotFollowingAuthor: And = + And(Not(OuterAuthorIsInnerAuthor), Not(OuterAuthorFollowsAuthor)) + + val IsFocalTweet: FeatureEquals[Long] = FeatureEquals(TweetId, FocalTweetId) + + val NonHydratingConditions: Set[Class[_]] = Set( + LoggedOutViewer, + NonAuthorViewer, + True, + TweetComposedAfter(Time.now), + TweetComposedBefore(Time.now) + ).map(_.getClass) + + trait HasParams { + val params: Set[Param[_]] + } + + def hasLabelCondition(condition: Condition, tweetSafetyLabelType: TweetSafetyLabelType): Boolean = + condition match { + case lt: HasSafetyLabelType => + lt.hasLabelType(tweetSafetyLabelType) + case _ => false + } + + def hasLabelCondition(condition: Condition, userLabelValue: UserLabelValue): Boolean = + condition match { + case lt: HasSafetyLabelType => + lt.hasLabelType(userLabelValue) + case _ => false + } + + def hasLabelCondition(condition: Condition, spaceSafetyLabelType: SpaceSafetyLabelType): Boolean = + condition match { + case lt: HasSafetyLabelType => + lt.hasLabelType(spaceSafetyLabelType) + case _ => false + } + + def hasLabelCondition(condition: Condition, mediaSafetyLabelType: MediaSafetyLabelType): Boolean = + condition match { + case lt: HasSafetyLabelType => + lt.hasLabelType(mediaSafetyLabelType) + case _ => false + } + + case class Choose[T]( + conditionMap: Map[T, Condition], + defaultCondition: Condition, + choiceParam: Param[T]) + extends Condition + with HasNestedConditions + with HasParams { + override lazy val name: String = + s"(Either ${conditionMap.values.map(_.name).mkString(", ")} or ${defaultCondition.name})" + override val features: Set[Feature[_]] = + conditionMap.values.flatMap(_.features).toSet ++ defaultCondition.features + override val optionalFeatures: Set[Feature[_]] = + conditionMap.values.flatMap(_.optionalFeatures).toSet ++ defaultCondition.optionalFeatures + override val conditions: Seq[Condition] = conditionMap.values.toSeq :+ defaultCondition + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + private[this] def getCondition(evaluationContext: EvaluationContext): Condition = + conditionMap.getOrElse(evaluationContext.params(choiceParam), defaultCondition) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + getCondition(evaluationContext).preFilter(evaluationContext, featureMap) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + getCondition(evaluationContext)(evaluationContext, featureMap) + } + + case class IfElse( + branchingCondition: Condition, + ifTrueCondition: Condition, + ifFalseCondition: Condition) + extends Condition + with HasNestedConditions + with HasParams { + override lazy val name: String = + s"(If ${branchingCondition.name} Then ${ifTrueCondition.name} Else ${ifFalseCondition.name})" + override val features: Set[Feature[_]] = + branchingCondition.features ++ ifTrueCondition.features ++ ifFalseCondition.features + override val optionalFeatures: Set[Feature[_]] = + branchingCondition.optionalFeatures ++ 
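+      // IfElse is a short-circuiting ternary over conditions: preFilter (below)
+      // follows whichever branch the branching condition's own preFilter picks,
+      // and degrades to NeedsFullEvaluation when that choice is undecided.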
ifTrueCondition.optionalFeatures ++ ifFalseCondition.optionalFeatures + override val conditions: Seq[Condition] = + Seq(branchingCondition, ifTrueCondition, ifFalseCondition) + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + branchingCondition.preFilter(evaluationContext, featureMap) match { + case Filtered => + ifFalseCondition.preFilter(evaluationContext, featureMap) + case NotFiltered => + ifTrueCondition.preFilter(evaluationContext, featureMap) + case _ => + NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = + if (branchingCondition(evaluationContext, featureMap).asBoolean) { + ifTrueCondition(evaluationContext, featureMap) + } else { + ifFalseCondition(evaluationContext, featureMap) + } + } + + case class GatedAlternate[T]( + defaultCondition: Condition, + alternateConditions: Map[T, Condition], + bucketIdentifierToUseOnDisagreementParam: Param[Option[T]]) + extends Condition + with HasNestedConditions + with HasParams { + + override lazy val name: String = + s"(${defaultCondition.name} or sometimes ${alternateConditions.values.map(_.name).mkString(" or ")})" + + override val features: Set[Feature[_]] = + defaultCondition.features ++ alternateConditions.values.flatMap(_.features) + + override val optionalFeatures: Set[Feature[_]] = + defaultCondition.optionalFeatures ++ alternateConditions.values.flatMap(_.optionalFeatures) + + override val conditions: Seq[Condition] = Seq(defaultCondition) ++ alternateConditions.values + + override val params: Set[Param[_]] = + conditions.collect { case p: HasParams => p.params }.flatten.toSet + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + if (defaultCondition.preFilter(evaluationContext, featureMap) == Filtered && + alternateConditions.values.forall(_.preFilter(evaluationContext, featureMap) == Filtered)) { + Filtered + } else if (defaultCondition.preFilter(evaluationContext, featureMap) == NotFiltered && + alternateConditions.values.forall( + _.preFilter(evaluationContext, featureMap) == NotFiltered)) { + NotFiltered + } else { + NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val defaultConditionResult: Result = defaultCondition(evaluationContext, featureMap) + val alternateConditionResult: Map[T, Result] = + alternateConditions.mapValues(_(evaluationContext, featureMap)) + + if (alternateConditionResult.values.exists(_.asBoolean != defaultConditionResult.asBoolean)) { + evaluationContext.params(bucketIdentifierToUseOnDisagreementParam) match { + case Some(bucket) if alternateConditionResult.contains(bucket) => + alternateConditionResult(bucket) + case _ => + defaultConditionResult + } + } else { + defaultConditionResult + } + } + } + + case class EnumGatedAlternate[E <: Enumeration]( + defaultCondition: Condition, + alternateConditions: Map[E#Value, Condition], + bucketIdentifierToUseOnDisagreementParam: EnumParam[E]) + extends Condition + with HasNestedConditions + with HasParams { + + override lazy val name: String = + s"(${defaultCondition.name} or sometimes ${alternateConditions.values.map(_.name).mkString(" or ")})" + + override val features: Set[Feature[_]] = + defaultCondition.features ++ 
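+      // Same disagreement protocol as GatedAlternate above, specialized to
+      // Enumeration-valued experiment buckets. Illustrative outcomes:
+      //   alternates all agree with default          => default's result
+      //   bucket's alternate disagrees               => that alternate's result
+      //   disagreement but no matching bucket        => default's result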
alternateConditions.values.flatMap(_.features) + + override val optionalFeatures: Set[Feature[_]] = + defaultCondition.optionalFeatures ++ alternateConditions.values.flatMap(_.optionalFeatures) + + override val conditions: Seq[Condition] = Seq(defaultCondition) ++ alternateConditions.values + + override val params: Set[Param[_]] = + conditions + .collect { + case p: HasParams => p.params + }.flatten.toSet + bucketIdentifierToUseOnDisagreementParam + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = + if (defaultCondition.preFilter(evaluationContext, featureMap) == Filtered && + alternateConditions.values.forall(_.preFilter(evaluationContext, featureMap) == Filtered)) { + Filtered + } else if (defaultCondition.preFilter(evaluationContext, featureMap) == NotFiltered && + alternateConditions.values.forall( + _.preFilter(evaluationContext, featureMap) == NotFiltered)) { + NotFiltered + } else { + NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val defaultConditionResult: Result = defaultCondition(evaluationContext, featureMap) + val alternateConditionResult: Map[E#Value, Result] = + alternateConditions.mapValues(_(evaluationContext, featureMap)) + + if (alternateConditionResult.values.exists(_.asBoolean != defaultConditionResult.asBoolean)) { + evaluationContext.params(bucketIdentifierToUseOnDisagreementParam) match { + case bucket if alternateConditionResult.contains(bucket) => + alternateConditionResult(bucket) + case _ => + defaultConditionResult + } + } else { + defaultConditionResult + } + } + } + + case object IsTestTweet extends Condition { + override val features: Set[Feature[_]] = Set(TweetId) + override val optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + if (!featureMap.contains(TweetId)) { + UnsatisfiedResult + } else { + Result.SatisfiedResult + } + } + } + + case object IsTweetInTweetLevelStcmHoldback extends Condition { + override val features: Set[Feature[_]] = Set(TweetId) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val tweetId: Long = featureMap(TweetId).asInstanceOf[Long] + if (StcmTweetHoldback.isTweetInHoldback(tweetId)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object MediaRestrictedInViewerCountry extends Condition { + override val features: Set[Feature[_]] = + Set(MediaGeoRestrictionsAllowList, MediaGeoRestrictionsDenyList) + override val optionalFeatures: Set[Feature[_]] = Set(RequestCountryCode) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val requestCountryCode = TakedownReasons.normalizeCountryCodeOption( + featureMap.get(RequestCountryCode).asInstanceOf[Option[String]]) + val allowlistCountryCodes = + featureMap(MediaGeoRestrictionsAllowList).asInstanceOf[Seq[String]] + val denylistCountryCodes = + featureMap(MediaGeoRestrictionsDenyList).asInstanceOf[Seq[String]] + if ((allowlistCountryCodes.nonEmpty && !allowlistCountryCodes.contains(requestCountryCode)) + || 
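+        // i.e. restricted when a non-empty allowlist excludes the request
+        // country, or (below) when the denylist names it: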
denylistCountryCodes.contains(requestCountryCode)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object OneToOneDmConversation + extends BooleanFeatureCondition(DmConversationIsOneToOneConversation) + + case object DmConversationTimelineIsEmpty + extends BooleanFeatureCondition(DmConversationHasEmptyTimeline) + + case object DmConversationLastReadableEventIdIsValid + extends BooleanFeatureCondition(DmConversationHasValidLastReadableEventId) + + case object ViewerIsDmConversationParticipant + extends BooleanFeatureCondition(feats.ViewerIsDmConversationParticipant) + + case object DmConversationInfoExists + extends BooleanFeatureCondition(feats.DmConversationInfoExists) + + case object DmConversationTimelineExists + extends BooleanFeatureCondition(feats.DmConversationTimelineExists) + + case object DmEventIsBeforeLastClearedEvent + extends BooleanFeatureCondition(DmEventOccurredBeforeLastClearedEvent) + + case object DmEventIsBeforeJoinConversationEvent + extends BooleanFeatureCondition(DmEventOccurredBeforeJoinConversationEvent) + + case object DmEventIsDeleted extends BooleanFeatureCondition(feats.DmEventIsDeleted) + + case object DmEventIsHidden extends BooleanFeatureCondition(feats.DmEventIsHidden) + + case object ViewerIsDmEventInitiatingUser + extends BooleanFeatureCondition(feats.ViewerIsDmEventInitiatingUser) + + case object DmEventInOneToOneConversationWithUnavailableUser + extends BooleanFeatureCondition(feats.DmEventInOneToOneConversationWithUnavailableUser) + + case object DmEventInOneToOneConversation + extends BooleanFeatureCondition(feats.DmEventInOneToOneConversation) + + case object MessageCreateDmEvent extends BooleanFeatureCondition(DmEventIsMessageCreateEvent) + + case object WelcomeMessageCreateDmEvent + extends BooleanFeatureCondition(DmEventIsWelcomeMessageCreateEvent) + + case object LastMessageReadUpdateDmEvent + extends BooleanFeatureCondition(DmEventIsLastMessageReadUpdateEvent) + + case object JoinConversationDmEvent + extends BooleanFeatureCondition(DmEventIsJoinConversationEvent) + + case object ConversationCreateDmEvent + extends BooleanFeatureCondition(DmEventIsConversationCreateEvent) + + case object TrustConversationDmEvent + extends BooleanFeatureCondition(DmEventIsTrustConversationEvent) + + case object CsFeedbackSubmittedDmEvent + extends BooleanFeatureCondition(DmEventIsCsFeedbackSubmitted) + + case object CsFeedbackDismissedDmEvent + extends BooleanFeatureCondition(DmEventIsCsFeedbackDismissed) + + case object PerspectivalJoinConversationDmEvent + extends BooleanFeatureCondition(feats.DmEventIsPerspectivalJoinConversationEvent) + + + case class SpaceHasLabelWithScoreAboveThresholdWithParam( + spaceSafetyLabelType: SpaceSafetyLabelType, + thresholdParam: Param[Double]) + extends Condition + with HasParams { + override lazy val name: String = + s"SpaceHasLabelWithScoreAboveThreshold(${spaceSafetyLabelType.name}, ${NamingUtils.getFriendlyName(thresholdParam)})" + override val features: Set[Feature[_]] = Set(SpaceSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + override val params: Set[Param[_]] = Set(thresholdParam) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(SpaceSafetyLabels).asInstanceOf[Seq[SpaceSafetyLabel]] + val threshold = evaluationContext.params(thresholdParam) + val SatisfiedResult = + 
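+      // built per evaluation, not cached, because the threshold is resolved
+      // from a Param on each request: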
Satisfied(FoundSpaceLabelWithScoreAboveThreshold(spaceSafetyLabelType, threshold)) + labels + .collectFirst { + case label + if label.safetyLabelType == spaceSafetyLabelType + && label.safetyLabel.score.exists(_ >= threshold) => + SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class CardUriHasRootDomain(rootDomainParam: Param[Seq[String]]) + extends Condition + with HasParams { + override lazy val name: String = + s"CardUriHasRootDomain(${NamingUtils.getFriendlyName(rootDomainParam)})" + override val features: Set[Feature[_]] = Set(CardUriHost) + override val optionalFeatures: Set[Feature[_]] = Set.empty + private val UnsatisfiedResult = Unsatisfied(this) + override val params: Set[Param[_]] = Set(rootDomainParam) + + private[this] def isHostDomainOrSubdomain(domain: String, host: String): Boolean = + host == domain || host.endsWith("." + domain) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val cardUriHost = featureMap(CardUriHost).asInstanceOf[String] + val rootDomains = evaluationContext.params(rootDomainParam) + + if (rootDomains.exists(isHostDomainOrSubdomain(_, cardUriHost))) { + Satisfied(FoundCardUriRootDomain(cardUriHost)) + } else { + UnsatisfiedResult + } + } + } + + case class TweetHasViolationOfLevel(level: ViolationLevel) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = s"tweetHasViolationOfLevel(${level})" + + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + override def optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + + private val SatisfiedResult: Satisfied = Satisfied(FoundTweetViolationOfLevel(level)) + + override val labelTypes: Set[SafetyLabelType] = + ViolationLevel.violationLevelToSafetyLabels + .getOrElse(level, Set.empty) + .map(_.asInstanceOf[SafetyLabelType]) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + if (labels.map(ViolationLevel.fromTweetSafetyLabel).contains(level)) { + SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object TweetHasViolationOfAnyLevel extends Condition with HasSafetyLabelType { + + override lazy val name: String = s"tweetHasViolationOfAnyLevel" + + override val features: Set[Feature[_]] = Set(TweetSafetyLabels) + + override def optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + + private val SatisfiedResult: Satisfied = Satisfied(FoundTweetViolationOfSomeLevel) + + override val labelTypes: Set[SafetyLabelType] = + ViolationLevel.violationLevelToSafetyLabels.values + .reduceLeft(_ ++ _) + .map(_.asInstanceOf[SafetyLabelType]) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(TweetSafetyLabels).asInstanceOf[Seq[TweetSafetyLabel]] + if (labels + .map(ViolationLevel.fromTweetSafetyLabelOpt).collect { + case Some(level) => level + }.nonEmpty) { + SatisfiedResult + } else { + UnsatisfiedResult + } + } + } + + case object TweetIsEditTweet extends BooleanFeatureCondition(TweetIsEditTweetFeature) + case object TweetIsStaleTweet extends BooleanFeatureCondition(TweetIsStaleTweetFeature) + + + case class ViewerHasAdultMediaSettingLevel(settingLevelToCompare: SensitiveMediaSettingsLevel) + extends Condition { + override def features: 
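+    // The three ViewerHas*MediaSettingLevel conditions (adult / violent / other,
+    // below) differ only in which UserSensitiveMediaSettings field they compare;
+    // an absent setting is Unsatisfied in all three.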
Set[Feature[_]] = Set(ViewerSensitiveMediaSettings) + + override def optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + featureMap + .get(ViewerSensitiveMediaSettings) + .map(_.asInstanceOf[UserSensitiveMediaSettings]) + .collectFirst { + case UserSensitiveMediaSettings(Some(setting)) + if (setting.viewAdultContent == settingLevelToCompare) => + Result.SatisfiedResult + case UserSensitiveMediaSettings(None) => UnsatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class ViewerHasViolentMediaSettingLevel(settingLevelToCompare: SensitiveMediaSettingsLevel) + extends Condition { + override def features: Set[Feature[_]] = Set(ViewerSensitiveMediaSettings) + + override def optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + + featureMap + .get(ViewerSensitiveMediaSettings) + .map(_.asInstanceOf[UserSensitiveMediaSettings]) + .collectFirst { + case UserSensitiveMediaSettings(Some(setting)) + if (setting.viewViolentContent == settingLevelToCompare) => + Result.SatisfiedResult + case UserSensitiveMediaSettings(None) => UnsatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class ViewerHasOtherSensitiveMediaSettingLevel( + settingLevelToCompare: SensitiveMediaSettingsLevel) + extends Condition { + override def features: Set[Feature[_]] = Set(ViewerSensitiveMediaSettings) + + override def optionalFeatures: Set[Feature[_]] = Set.empty + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + + featureMap + .get(ViewerSensitiveMediaSettings) + .map(_.asInstanceOf[UserSensitiveMediaSettings]) + .collectFirst { + case UserSensitiveMediaSettings(Some(setting)) + if (setting.viewOtherContent == settingLevelToCompare) => + Result.SatisfiedResult + case UserSensitiveMediaSettings(None) => UnsatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + private[rules] val ToxrfTweetFilteredForAuthor = + Equals(ToxicReplyFilterState, FilterState.FilteredFromAuthor) + + private[rules] case object ToxrfViewerIsConversationAuthor + extends BooleanFeatureCondition(ToxicReplyFilterConversationAuthorIsViewer) + + val ToxrfFilteredFromAuthorViewer = + And(LoggedInViewer, ToxrfTweetFilteredForAuthor, ToxrfViewerIsConversationAuthor) + + case object SearchQueryMatchesScreenName extends Condition { + override def features: Set[Feature[_]] = Set.empty + + override def optionalFeatures: Set[Feature[_]] = Set(RawQuery, AuthorScreenName) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + if (featureMap.contains(RawQuery) && featureMap.contains(AuthorScreenName)) { + val rawQuery = featureMap(RawQuery).asInstanceOf[String] + val authorScreenName = featureMap(AuthorScreenName).asInstanceOf[String] + if (rawQuery.equalsIgnoreCase(authorScreenName)) { + Result.SatisfiedResult + } else { + UnsatisfiedResult + } + } else { + UnsatisfiedResult + } + } + } + + object SearchQueryDoesNotMatchScreenNameConditionBuilder { + def apply(condition: Condition, ruleParam: RuleParam[Boolean]): Choose[Boolean] = { + Choose( + conditionMap = + Map(true -> And(condition, 
Not(SearchQueryMatchesScreenName)), false -> condition), + defaultCondition = condition, + choiceParam = ruleParam + ) + } + } + + val SearchQueryDoesNotMatchScreenNameDefaultTrueCondition: Choose[Boolean] = + SearchQueryDoesNotMatchScreenNameConditionBuilder(Condition.True, RuleParams.False) + + case object OptionalNonAuthorViewer extends Condition { + override val features: Set[Feature[_]] = Set() + override val optionalFeatures: Set[Feature[_]] = Set(AuthorId, ViewerId) + + private val UnsatisfiedResult = Unsatisfied(this) + + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): PreFilterResult = { + NeedsFullEvaluation + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val authorIdsOpt = featureMap.get(AuthorId).asInstanceOf[Option[Set[Long]]] + val viewerIdOpt = featureMap.get(ViewerId).asInstanceOf[Option[Long]] + + (for { + authorIds <- authorIdsOpt + viewerId <- viewerIdOpt + } yield { + if (authorIds.contains(viewerId)) UnsatisfiedResult + else Result.SatisfiedResult + }) getOrElse { + Result.SatisfiedResult + } + } + } + + case class ViewerLocatedInApplicableCountriesOfMediaWithholdingLabel( + safetyLabelType: MediaSafetyLabelType) + extends ViewerInJurisdiction + with HasSafetyLabelType { + + override lazy val name: String = + s"ViewerLocatedInApplicableCountriesOfMediaLabel(${safetyLabelType.name})" + override val features: Set[Feature[_]] = Set(MediaSafetyLabels) + override val optionalFeatures: Set[Feature[_]] = Set(ViewerCountryCode, RequestCountryCode) + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + + private[this] def isInApplicableCountries( + countryCodeOpt: Option[String], + label: SafetyLabel + ): Boolean = { + val inApplicableCountry = (for { + applicableCountries <- label.applicableCountries + countryCode <- countryCodeOpt + } yield { + applicableCountries.contains(countryCode) + }) getOrElse (false) + inApplicableCountry + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(MediaSafetyLabels).asInstanceOf[Seq[MediaSafetyLabel]] + + val countryFeatures = getCountryFeatures(featureMap) + val countryCodeOpt = countryFeatures.requestCountryCode + .orElse(countryFeatures.viewerCountryCode) + + labels + .collectFirst { + case label + if label.safetyLabelType == safetyLabelType + && + isInApplicableCountries(countryCodeOpt, label.safetyLabel) => + Result.SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + } + + case class MediaHasLabelWithWorldwideWithholding(safetyLabelType: MediaSafetyLabelType) + extends Condition + with HasSafetyLabelType { + + override lazy val name: String = + s"MediaHasLabelWithWorldwideWithholding(${safetyLabelType.name})" + + override val features: Set[Feature[_]] = Set(MediaSafetyLabels) + + override val optionalFeatures: Set[Feature[_]] = Set.empty + + override val labelTypes: Set[SafetyLabelType] = Set(safetyLabelType) + + private val UnsatisfiedResult: Unsatisfied = Unsatisfied(this) + + private[this] def isWithheldWorldwide(label: SafetyLabel): Boolean = { + label.applicableCountries.map(_.contains("xx")).getOrElse(false) + } + + override def apply( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): Result = { + val labels = featureMap(MediaSafetyLabels).asInstanceOf[Seq[MediaSafetyLabel]] + + labels + .collectFirst { + 
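+        // "xx" (checked by isWithheldWorldwide above) is the sentinel
+        // applicable-country code meaning "withheld everywhere":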
case label + if label.safetyLabelType == safetyLabelType + && isWithheldWorldwide(label.safetyLabel) => + Result.SatisfiedResult + }.getOrElse(UnsatisfiedResult) + } + + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmConversationRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmConversationRules.scala new file mode 100644 index 000000000..b9453cbe0 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmConversationRules.scala @@ -0,0 +1,50 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.DmConversationLastReadableEventIdIsValid +import com.twitter.visibility.rules.Condition.DmConversationTimelineIsEmpty +import com.twitter.visibility.rules.Condition.ViewerIsDmConversationParticipant +import com.twitter.visibility.rules.Condition.DmConversationInfoExists +import com.twitter.visibility.rules.Condition.DmConversationTimelineExists +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.DeactivatedAuthor +import com.twitter.visibility.rules.Condition.ErasedAuthor +import com.twitter.visibility.rules.Condition.OneToOneDmConversation +import com.twitter.visibility.rules.Condition.Or +import com.twitter.visibility.rules.Condition.SuspendedAuthor +import com.twitter.visibility.rules.Reason.Unspecified + +object DmConversationRules { + + object DropEmptyDmConversationRule + extends RuleWithConstantAction( + Drop(Unspecified), + Or( + Not(DmConversationLastReadableEventIdIsValid), + And(OneToOneDmConversation, DmConversationTimelineIsEmpty))) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object DropInaccessibleDmConversationRule + extends RuleWithConstantAction(Drop(Unspecified), Not(ViewerIsDmConversationParticipant)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object DropDmConversationWithUndefinedConversationInfoRule + extends RuleWithConstantAction(Drop(Unspecified), Not(DmConversationInfoExists)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object DropDmConversationWithUndefinedConversationTimelineRule + extends RuleWithConstantAction(Drop(Unspecified), Not(DmConversationTimelineExists)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object DropOneToOneDmConversationWithUnavailableParticipantsRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(OneToOneDmConversation, Or(SuspendedAuthor, DeactivatedAuthor, ErasedAuthor))) { + override def enableFailClosed = Seq(RuleParams.True) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmEventRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmEventRules.scala new file mode 100644 index 000000000..30bff031c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmEventRules.scala @@ -0,0 +1,90 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.rules.Reason.Unspecified +import com.twitter.visibility.rules.Condition.DeactivatedAuthor +import com.twitter.visibility.rules.Condition.ErasedAuthor +import com.twitter.visibility.rules.Condition.SuspendedAuthor +import com.twitter.visibility.rules.Condition.DmEventInOneToOneConversationWithUnavailableUser +import com.twitter.visibility.rules.Condition.DmEventIsBeforeLastClearedEvent +import 
com.twitter.visibility.rules.Condition.DmEventIsBeforeJoinConversationEvent +import com.twitter.visibility.rules.Condition.DmEventIsDeleted +import com.twitter.visibility.rules.Condition.DmEventIsHidden +import com.twitter.visibility.rules.Condition.LastMessageReadUpdateDmEvent +import com.twitter.visibility.rules.Condition.MessageCreateDmEvent +import com.twitter.visibility.rules.Condition.PerspectivalJoinConversationDmEvent +import com.twitter.visibility.rules.Condition.ViewerIsDmEventInitiatingUser +import com.twitter.visibility.rules.Condition.ViewerIsDmConversationParticipant +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.CsFeedbackDismissedDmEvent +import com.twitter.visibility.rules.Condition.CsFeedbackSubmittedDmEvent +import com.twitter.visibility.rules.Condition.JoinConversationDmEvent +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.Or +import com.twitter.visibility.rules.Condition.TrustConversationDmEvent +import com.twitter.visibility.rules.Condition.WelcomeMessageCreateDmEvent +import com.twitter.visibility.rules.Condition.DmEventInOneToOneConversation +import com.twitter.visibility.rules.Condition.ConversationCreateDmEvent + +object DmEventRules { + + object MessageCreateEventWithUnavailableSenderDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + Or(SuspendedAuthor, DeactivatedAuthor, ErasedAuthor)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object WelcomeMessageCreateEventOnlyVisibleToRecipientDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(ViewerIsDmEventInitiatingUser, WelcomeMessageCreateDmEvent)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object InaccessibleDmEventDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + Or( + Not(ViewerIsDmConversationParticipant), + DmEventIsBeforeLastClearedEvent, + DmEventIsBeforeJoinConversationEvent)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object HiddenAndDeletedDmEventDropRule + extends RuleWithConstantAction(Drop(Unspecified), Or(DmEventIsDeleted, DmEventIsHidden)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object NonPerspectivalDmEventDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + Or( + And(Not(PerspectivalJoinConversationDmEvent), JoinConversationDmEvent), + And( + Not(ViewerIsDmEventInitiatingUser), + Or(TrustConversationDmEvent, CsFeedbackSubmittedDmEvent, CsFeedbackDismissedDmEvent)) + ) + ) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object DmEventInOneToOneConversationWithUnavailableUserDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + Or(MessageCreateDmEvent, LastMessageReadUpdateDmEvent), + DmEventInOneToOneConversationWithUnavailableUser)) { + override def enableFailClosed = Seq(RuleParams.True) + } + + object GroupEventInOneToOneConversationDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + Or(JoinConversationDmEvent, ConversationCreateDmEvent), + DmEventInOneToOneConversation)) { + override def enableFailClosed = Seq(RuleParams.True) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmVisibilityPolicies.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmVisibilityPolicies.scala new file mode 100644 index 000000000..c32ba6d84 --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DmVisibilityPolicies.scala @@ -0,0 +1,130 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.rules.DmConversationRules._ +import com.twitter.visibility.rules.DmEventRules._ +import com.twitter.visibility.rules.PolicyLevelRuleParams.ruleParams + +object SensitiveMediaSettingsDirectMessagesBaseRules { + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam), + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam), + NsfwReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam), + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam), + NsfwCardImageAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaDirectMessagesRulesParam) + ) +} + +case object DirectMessagesPolicy + extends VisibilityPolicy( + tweetRules = TweetDetailPolicy.tweetRules.diff(LimitedEngagementBaseRules.tweetRules), + dmRules = Seq( + DeactivatedAuthorRule, + ErasedAuthorRule + ), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) + +case object DirectMessagesMutedUsersPolicy + extends VisibilityPolicy( + userRules = Seq(SuspendedAuthorRule) + ) + +case object DirectMessagesSearchPolicy + extends VisibilityPolicy( + dmConversationRules = Seq( + DropDmConversationWithUndefinedConversationInfoRule, + DropDmConversationWithUndefinedConversationTimelineRule, + DropInaccessibleDmConversationRule, + DropEmptyDmConversationRule, + DropOneToOneDmConversationWithUnavailableParticipantsRule + ), + dmEventRules = Seq( + InaccessibleDmEventDropRule, + HiddenAndDeletedDmEventDropRule, + MessageCreateEventWithUnavailableSenderDropRule), + userRules = Seq(ErasedAuthorRule, DeactivatedAuthorRule, SuspendedAuthorRule), + tweetRules = + Seq(ViewerBlocksAuthorRule, ViewerMutesAuthorRule) ++ TweetDetailPolicy.tweetRules.diff( + LimitedEngagementBaseRules.tweetRules), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) + +case object DirectMessagesPinnedPolicy + extends VisibilityPolicy( + dmConversationRules = Seq( + DropDmConversationWithUndefinedConversationInfoRule, + DropDmConversationWithUndefinedConversationTimelineRule, + DropInaccessibleDmConversationRule, + DropEmptyDmConversationRule, + DropOneToOneDmConversationWithUnavailableParticipantsRule + ), + dmEventRules = Seq( + InaccessibleDmEventDropRule, + HiddenAndDeletedDmEventDropRule, + MessageCreateEventWithUnavailableSenderDropRule), + userRules = Seq(ErasedAuthorRule, DeactivatedAuthorRule, SuspendedAuthorRule), + tweetRules = + Seq(ViewerBlocksAuthorRule, ViewerMutesAuthorRule) ++ TweetDetailPolicy.tweetRules.diff( + LimitedEngagementBaseRules.tweetRules), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) + +case object DirectMessagesConversationListPolicy + extends VisibilityPolicy( + dmConversationRules = Seq( + DropDmConversationWithUndefinedConversationInfoRule, + DropDmConversationWithUndefinedConversationTimelineRule, + DropInaccessibleDmConversationRule, + DropEmptyDmConversationRule, + DropOneToOneDmConversationWithUnavailableParticipantsRule + ), + 
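+      // note: no dmEventRules here; the conversation-list surface renders
+      // conversations, not individual events (contrast the timeline and inbox
+      // policies below)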
userRules = Seq(ErasedAuthorRule, DeactivatedAuthorRule, SuspendedAuthorRule), + tweetRules = + Seq(ViewerBlocksAuthorRule, ViewerMutesAuthorRule) ++ TweetDetailPolicy.tweetRules.diff( + LimitedEngagementBaseRules.tweetRules), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) + +case object DirectMessagesConversationTimelinePolicy + extends VisibilityPolicy( + dmEventRules = Seq( + InaccessibleDmEventDropRule, + HiddenAndDeletedDmEventDropRule, + MessageCreateEventWithUnavailableSenderDropRule), + userRules = Seq(ErasedAuthorRule, DeactivatedAuthorRule, SuspendedAuthorRule), + tweetRules = + Seq(ViewerBlocksAuthorRule, ViewerMutesAuthorRule) ++ TweetDetailPolicy.tweetRules.diff( + LimitedEngagementBaseRules.tweetRules), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) + +case object DirectMessagesInboxPolicy + extends VisibilityPolicy( + dmConversationRules = Seq( + DropDmConversationWithUndefinedConversationInfoRule, + DropDmConversationWithUndefinedConversationTimelineRule, + DropInaccessibleDmConversationRule, + DropEmptyDmConversationRule, + DropOneToOneDmConversationWithUnavailableParticipantsRule + ), + dmEventRules = Seq( + InaccessibleDmEventDropRule, + HiddenAndDeletedDmEventDropRule, + DmEventInOneToOneConversationWithUnavailableUserDropRule, + MessageCreateEventWithUnavailableSenderDropRule, + NonPerspectivalDmEventDropRule, + WelcomeMessageCreateEventOnlyVisibleToRecipientDropRule, + GroupEventInOneToOneConversationDropRule + ), + userRules = Seq(ErasedAuthorRule, DeactivatedAuthorRule, SuspendedAuthorRule), + tweetRules = + Seq(ViewerBlocksAuthorRule, ViewerMutesAuthorRule) ++ TweetDetailPolicy.tweetRules.diff( + LimitedEngagementBaseRules.tweetRules), + policyRuleParams = SensitiveMediaSettingsDirectMessagesBaseRules.policyRuleParams + ) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/DownrankingRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DownrankingRules.scala new file mode 100644 index 000000000..7126b8444 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/DownrankingRules.scala @@ -0,0 +1,207 @@ +package com.twitter.visibility.rules + +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.common.ModelScoreThresholds +import com.twitter.visibility.configapi.configs.DeciderKey +import com.twitter.visibility.configapi.params.FSRuleParams.HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThresholdParam +import com.twitter.visibility.configapi.params.RuleParams.EnableDownrankSpamReplySectioningRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableNotGraduatedDownrankConvosAbusiveQualityRuleParam +import com.twitter.visibility.configapi.params.RuleParams.NotGraduatedUserLabelRuleHoldbackExperimentParam +import com.twitter.visibility.configapi.params.TimelineConversationsDownrankingSpecificParams._ +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.models.UserLabelValue +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.RuleActionSourceBuilder.TweetSafetyLabelSourceBuilder +import com.twitter.visibility.rules.RuleActionSourceBuilder.UserSafetyLabelSourceBuilder + +object DownrankingRules { + + val ToxicityScoreAboveDownrankAbusiveQualitySectionThresholdCondition: TweetHasLabelWithLanguageScoreAboveThreshold = + 
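+  // One HighToxicityScore label, three per-language threshold maps: the same
+  // score can route a reply to the abusive-quality, low-quality, or
+  // high-quality conversation section. Shape of the maps (values assumed,
+  // for illustration only):
+  //   Map("en" -> 0.91, "ja" -> 0.84)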
TweetHasLabelWithLanguageScoreAboveThreshold( + safetyLabel = TweetSafetyLabelType.HighToxicityScore, + languagesToScoreThresholds = ModelScoreThresholds.ToxicityAbusiveQualityLanguagesToThresholds + ) + + val ToxicityScoreAboveDownrankLowQualitySectionThresholdCondition: TweetHasLabelWithLanguageScoreAboveThreshold = + TweetHasLabelWithLanguageScoreAboveThreshold( + safetyLabel = TweetSafetyLabelType.HighToxicityScore, + languagesToScoreThresholds = ModelScoreThresholds.ToxicityLowQualityLanguagesToThresholds + ) + + val ToxicityScoreAboveDownrankHighQualitySectionThresholdCondition: TweetHasLabelWithLanguageScoreAboveThreshold = + TweetHasLabelWithLanguageScoreAboveThreshold( + safetyLabel = TweetSafetyLabelType.HighToxicityScore, + languagesToScoreThresholds = ModelScoreThresholds.ToxicityHighQualityLanguagesToThresholds + ) + + val HighSpammyTweetContentScoreConvoDownrankAbusiveQualityCondition: Condition = + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighSpammyTweetContentScore, + HighSpammyTweetContentScoreConvoDownrankAbusiveQualityThresholdParam) + + val HighCryptospamScoreConvoDownrankAbusiveQualityCondition: Condition = + TweetHasLabel(TweetSafetyLabelType.HighCryptospamScore) +} + +object HighToxicityScoreDownrankHighQualitySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + Downrank, + DownrankingRules.ToxicityScoreAboveDownrankHighQualitySectionThresholdCondition + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighToxicityScore)) + + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object HighToxicityScoreDownrankLowQualitySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionLowQuality, + DownrankingRules.ToxicityScoreAboveDownrankLowQualitySectionThresholdCondition + ) + with DoesLogVerdict { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighToxicityScore)) +} + +object HighToxicityScoreDownrankAbusiveQualitySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + DownrankingRules.ToxicityScoreAboveDownrankAbusiveQualitySectionThresholdCondition + ) + with DoesLogVerdict { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighToxicityScore)) +} + +object UntrustedUrlConversationsTweetLabelRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + TweetHasLabel(TweetSafetyLabelType.UntrustedUrl) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.UntrustedUrl)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object DownrankSpamReplyConversationsTweetLabelRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + TweetHasLabel(TweetSafetyLabelType.DownrankSpamReply) + ) + with DoesLogVerdictDecidered { + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) + + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.DownrankSpamReply)) + override def verdictLogDeciderKey = 
DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object DownrankSpamReplyConversationsAuthorLabelRule + extends AuthorLabelWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + UserLabelValue.DownrankSpamReply + ) + with DoesLogVerdictDecidered { + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging + + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + UserSafetyLabelSourceBuilder(UserLabelValue.DownrankSpamReply)) +} + +object NotGraduatedConversationsAuthorLabelRule + extends AuthorLabelWithNotInnerCircleOfFriendsRule( + ConversationSectionLowQuality, + UserLabelValue.NotGraduated + ) + with DoesLogVerdictDecidered { + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNotGraduatedDownrankConvosAbusiveQualityRuleParam) + + override def holdbacks: Seq[RuleParam[Boolean]] = Seq( + NotGraduatedUserLabelRuleHoldbackExperimentParam) + + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + UserSafetyLabelSourceBuilder(UserLabelValue.NotGraduated)) +} + +object HighProactiveTosScoreTweetLabelDownrankingRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + TweetHasLabel(TweetSafetyLabelType.HighProactiveTosScore) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighProactiveTosScore)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object HighPSpammyTweetScoreDownrankLowQualitySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + action = ConversationSectionLowQuality, + condition = TweetHasLabelWithScoreAboveThreshold( + TweetSafetyLabelType.HighPSpammyTweetScore, + ModelScoreThresholds.HighPSpammyTweetScoreThreshold) + ) + with DoesLogVerdictDecidered { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnablePSpammyTweetDownrankConvosLowQualityParam) + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighPSpammyTweetScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging +} + +object HighSpammyTweetContentScoreConvoDownrankAbusiveQualityRule + extends ConditionWithNotInnerCircleOfFriendsRule( + action = ConversationSectionAbusiveQuality, + condition = And( + Not(IsTweetInTweetLevelStcmHoldback), + DownrankingRules.HighSpammyTweetContentScoreConvoDownrankAbusiveQualityCondition) + ) + with DoesLogVerdictDecidered { + override def isEnabled(params: Params): Boolean = { + params(EnableHighSpammyTweetContentScoreConvoDownrankAbusiveQualityRuleParam) + } + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighSpammyTweetContentScore)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object HighCryptospamScoreConvoDownrankAbusiveQualityRule + extends ConditionWithNotInnerCircleOfFriendsRule( + action = ConversationSectionAbusiveQuality, + condition = DownrankingRules.HighCryptospamScoreConvoDownrankAbusiveQualityCondition + ) + with DoesLogVerdictDecidered { + override def isEnabled(params: Params): Boolean = { + 
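+    // gates on a single param via a direct isEnabled override, rather than the
+    // enabled: Seq[RuleParam[Boolean]] hook used by neighboring rules: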
params(EnableHighCryptospamScoreConvoDownrankAbusiveQualityRuleParam) + } + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighCryptospamScore)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object RitoActionedTweetDownrankLowQualitySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + action = ConversationSectionLowQuality, + condition = TweetHasLabel(TweetSafetyLabelType.RitoActionedTweet) + ) + with DoesLogVerdictDecidered { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableRitoActionedTweetDownrankConvosLowQualityParam) + + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.RitoActionedTweet)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/EvaluationContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/EvaluationContext.scala new file mode 100644 index 000000000..545bce460 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/EvaluationContext.scala @@ -0,0 +1,68 @@ +package com.twitter.visibility.rules + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.servo.util.Gate +import com.twitter.timelines.configapi.HasParams +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.configapi.VisibilityParams +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.UnitOfDiversion +import com.twitter.visibility.models.ViewerContext + +case class EvaluationContext( + visibilityPolicy: VisibilityPolicy, + params: Params, + statsReceiver: StatsReceiver) + extends HasParams { + + def ruleEnabledInContext(rule: Rule): Boolean = { + visibilityPolicy.policyRuleParams + .get(rule) + .filter(_.ruleParams.nonEmpty) + .map(policyRuleParams => { + (policyRuleParams.force || rule.enabled.forall(params(_))) && + policyRuleParams.ruleParams.forall(params(_)) + }) + .getOrElse(rule.isEnabled(params)) + } +} + +object EvaluationContext { + + def apply( + safetyLevel: SafetyLevel, + params: Params, + statsReceiver: StatsReceiver + ): EvaluationContext = { + val visibilityPolicy = RuleBase.RuleMap(safetyLevel) + new EvaluationContext(visibilityPolicy, params, statsReceiver) + } + + case class Builder( + statsReceiver: StatsReceiver, + visibilityParams: VisibilityParams, + viewerContext: ViewerContext, + unitsOfDiversion: Seq[UnitOfDiversion] = Seq.empty, + memoizeParams: Gate[Unit] = Gate.False, + ) { + + private[this] val emptyContentToUoDCounter = + statsReceiver.counter("empty_content_id_to_unit_of_diversion") + + def build(safetyLevel: SafetyLevel): EvaluationContext = { + val policy = RuleBase.RuleMap(safetyLevel) + val params = if (memoizeParams()) { + visibilityParams.memoized(viewerContext, safetyLevel, unitsOfDiversion) + } else { + visibilityParams(viewerContext, safetyLevel, unitsOfDiversion) + } + new EvaluationContext(policy, params, statsReceiver) + } + + def withUnitOfDiversion(unitOfDiversion: UnitOfDiversion*): Builder = + this.copy(unitsOfDiversion = unitOfDiversion) + + def withMemoizedParams(memoizeParams: Gate[Unit]) = this.copy(memoizeParams = memoizeParams) + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/ExperimentBase.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ExperimentBase.scala new file mode 
100644 index 000000000..840ca5312 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ExperimentBase.scala @@ -0,0 +1,18 @@ +package com.twitter.visibility.rules + +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.configapi.params.LabelSourceParam +import com.twitter.visibility.models.LabelSource + +object ExperimentBase { + val sourceToParamMap: Map[LabelSource, LabelSourceParam] = Map.empty + + final def shouldFilterForSource(params: Params, labelSourceOpt: Option[LabelSource]): Boolean = { + labelSourceOpt + .map { source => + val param = ExperimentBase.sourceToParamMap.get(source) + param.map(params.apply).getOrElse(true) + } + .getOrElse(true) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/FailClosedException.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FailClosedException.scala new file mode 100644 index 000000000..f3f99f4a7 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FailClosedException.scala @@ -0,0 +1,41 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.features.Feature +import com.twitter.visibility.rules.State.FeatureFailed +import com.twitter.visibility.rules.State.MissingFeature +import com.twitter.visibility.rules.State.RuleFailed + +abstract class FailClosedException(message: String, state: State, ruleName: String) + extends Exception(message) { + def getState: State = { + state + } + + def getRuleName: String = { + ruleName + } +} + +case class MissingFeaturesException( + ruleName: String, + missingFeatures: Set[Feature[_]]) + extends FailClosedException( + s"A $ruleName rule evaluation has ${missingFeatures.size} missing features: ${missingFeatures + .map(_.name).mkString(", ")}", + MissingFeature(missingFeatures), + ruleName) + +case class FeaturesFailedException( + ruleName: String, + featureFailures: Map[Feature[_], Throwable]) + extends FailClosedException( + s"A $ruleName rule evaluation has ${featureFailures.size} failed features: ${featureFailures.keys + .map(_.name).mkString(", ")}, ${featureFailures.values.map(_.getMessage).mkString("; ")}", + FeatureFailed(featureFailures), + ruleName) + +case class RuleFailedException(ruleName: String, exception: Throwable) + extends FailClosedException( + s"A $ruleName rule evaluation failed to execute", + RuleFailed(exception), + ruleName) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/FollowerRelations.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FollowerRelations.scala new file mode 100644 index 000000000..1492db7fa --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FollowerRelations.scala @@ -0,0 +1,20 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.features.AuthorMutesViewer +import com.twitter.visibility.rules.Condition.BooleanFeatureCondition +import com.twitter.visibility.rules.Condition.ProtectedViewer +import com.twitter.visibility.rules.Reason.Unspecified + +object FollowerRelations { + + case object AuthorMutesViewerFeature extends BooleanFeatureCondition(AuthorMutesViewer) + + object AuthorMutesViewerRule + extends OnlyWhenNotAuthorViewerRule( + action = Drop(Unspecified), + condition = AuthorMutesViewerFeature) + + object ProtectedViewerRule + extends OnlyWhenNotAuthorViewerRule(action = Drop(Unspecified), condition = ProtectedViewer) + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/ForEmergencyUseOnly.scala
b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ForEmergencyUseOnly.scala new file mode 100644 index 000000000..fabffaae2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ForEmergencyUseOnly.scala @@ -0,0 +1,100 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.common.actions.ComplianceTweetNoticeEventType +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableSearchIpiSafeSearchWithoutUserInQueryDropRule +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.TweetSafetyLabels +import com.twitter.visibility.models.LabelSource.StringSource +import com.twitter.visibility.models.LabelSource.parseStringSource +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.LoggedOutOrViewerOptInFiltering +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.SearchQueryHasUser +import com.twitter.visibility.rules.Condition.TweetHasLabel +import com.twitter.visibility.rules.Reason.Unspecified + +object EmergencyDynamicInterstitialActionBuilder + extends ActionBuilder[EmergencyDynamicInterstitial] { + + def actionType: Class[_] = classOf[EmergencyDynamicInterstitial] + + override val actionSeverity = 11 + override def build( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): RuleResult = { + val label = featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .find(slv => slv.labelType == TweetSafetyLabelType.ForEmergencyUseOnly) + + label.flatMap(_.source) match { + case Some(StringSource(name)) => + val (copy, linkOpt) = parseStringSource(name) + RuleResult(EmergencyDynamicInterstitial(copy, linkOpt), State.Evaluated) + + case _ => + Rule.EvaluatedRuleResult + } + } +} + +object EmergencyDynamicComplianceTweetNoticeActionBuilder + extends ActionBuilder[ComplianceTweetNoticePreEnrichment] { + + def actionType: Class[_] = classOf[ComplianceTweetNoticePreEnrichment] + + override val actionSeverity = 2 + override def build( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): RuleResult = { + val label = featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .find(slv => slv.labelType == TweetSafetyLabelType.ForEmergencyUseOnly) + + label.flatMap(_.source) match { + case Some(StringSource(name)) => + val (copy, linkOpt) = parseStringSource(name) + RuleResult( + ComplianceTweetNoticePreEnrichment( + reason = Unspecified, + complianceTweetNoticeEventType = ComplianceTweetNoticeEventType.PublicInterest, + details = Some(copy), + extendedDetailsUrl = linkOpt + ), + State.Evaluated + ) + + case _ => + Rule.EvaluatedRuleResult + } + } +} + +object EmergencyDynamicInterstitialRule + extends Rule( + EmergencyDynamicInterstitialActionBuilder, + TweetHasLabel(TweetSafetyLabelType.ForEmergencyUseOnly) + ) + +object EmergencyDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + TweetHasLabel(TweetSafetyLabelType.ForEmergencyUseOnly) + ) + +object SearchEdiSafeSearchWithoutUserInQueryDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + TweetHasLabel(TweetSafetyLabelType.ForEmergencyUseOnly), + LoggedOutOrViewerOptInFiltering, + Not(SearchQueryHasUser) + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + 
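+    // NOTE: the param that follows is the *Ipi* search kill switch, although this rule is
+    // the *Edi* (EmergencyDynamicInterstitial) variant; the two safe-search drop rules
+    // appear to share one decider -- a reading inferred from the names only.
+    // Hedged sketch of how `enabled` gates a rule (Rule.isEnabled, later in this patch):
+    // a rule runs only when every listed RuleParam resolves to true, i.e.
+    //   isEnabled(params) == enabled.forall(params(_))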
EnableSearchIpiSafeSearchWithoutUserInQueryDropRule) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/FreedomOfSpeechNotReach.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FreedomOfSpeechNotReach.scala new file mode 100644 index 000000000..ba2861e60 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/FreedomOfSpeechNotReach.scala @@ -0,0 +1,705 @@ +package com.twitter.visibility.rules + +import com.twitter.spam.rtf.thriftscala.SafetyResultReason +import com.twitter.util.Memoize +import com.twitter.visibility.common.actions.AppealableReason +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.common.actions.SoftInterventionDisplayType +import com.twitter.visibility.common.actions.SoftInterventionReason +import com.twitter.visibility.common.actions.LimitedActionsPolicy +import com.twitter.visibility.common.actions.LimitedAction +import com.twitter.visibility.common.actions.converter.scala.LimitedActionTypeConverter +import com.twitter.visibility.configapi.params.FSRuleParams.FosnrFallbackDropRulesEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.FosnrRulesEnabledParam +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableFosnrRuleParam +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.TweetSafetyLabels +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.models.ViolationLevel +import com.twitter.visibility.rules.ComposableActions.ComposableActionsWithInterstitialLimitedEngagements +import com.twitter.visibility.rules.ComposableActions.ComposableActionsWithSoftIntervention +import com.twitter.visibility.rules.ComposableActions.ComposableActionsWithAppealable +import com.twitter.visibility.rules.ComposableActions.ComposableActionsWithInterstitial +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.NonAuthorViewer +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.ViewerDoesNotFollowAuthorOfFosnrViolatingTweet +import com.twitter.visibility.rules.Condition.ViewerFollowsAuthorOfFosnrViolatingTweet +import com.twitter.visibility.rules.FreedomOfSpeechNotReach.DefaultViolationLevel +import com.twitter.visibility.rules.Reason._ +import com.twitter.visibility.rules.State.Evaluated + +object FreedomOfSpeechNotReach { + + val DefaultViolationLevel = ViolationLevel.Level1 + + def reasonToSafetyResultReason(reason: Reason): SafetyResultReason = + reason match { + case HatefulConduct => SafetyResultReason.FosnrHatefulConduct + case AbusiveBehavior => SafetyResultReason.FosnrAbusiveBehavior + case _ => SafetyResultReason.FosnrUnspecified + } + + def reasonToSafetyResultReason(reason: AppealableReason): SafetyResultReason = + reason match { + case AppealableReason.HatefulConduct(_) => SafetyResultReason.FosnrHatefulConduct + case AppealableReason.AbusiveBehavior(_) => SafetyResultReason.FosnrAbusiveBehavior + case _ => SafetyResultReason.FosnrUnspecified + } + + val EligibleTweetSafetyLabelTypes: Seq[TweetSafetyLabelType] = + Seq(ViolationLevel.Level4, ViolationLevel.Level3, ViolationLevel.Level2, ViolationLevel.Level1) + .map { + ViolationLevel.violationLevelToSafetyLabels.get(_).getOrElse(Set()).toSeq + }.reduceLeft { + _ ++ _ + } + + private val EligibleTweetSafetyLabelTypesSet 
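+  // The Seq above is deliberately ordered Level4 -> Level1 so that extractTweetSafetyLabel
+  // (below) surfaces the highest violation level first; this Set is the same collection in
+  // O(1)-membership form for the filtering pass. Hedged illustration, assuming a tweet
+  // carries both a Level1 and a Level3 FOSNR label:
+  //   extractTweetSafetyLabel(featureMap) // => Some(<the Level3 label>), never the Level1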
= EligibleTweetSafetyLabelTypes.toSet + + def extractTweetSafetyLabel(featureMap: Map[Feature[_], _]): Option[TweetSafetyLabel] = { + val tweetSafetyLabels = featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .flatMap { tsl => + if (FreedomOfSpeechNotReach.EligibleTweetSafetyLabelTypesSet.contains(tsl.labelType)) { + Some(tsl.labelType -> tsl) + } else { + None + } + } + .toMap + + FreedomOfSpeechNotReach.EligibleTweetSafetyLabelTypes.flatMap(tweetSafetyLabels.get).headOption + } + + def eligibleTweetSafetyLabelTypesToAppealableReason( + labelType: TweetSafetyLabelType, + violationLevel: ViolationLevel + ): AppealableReason = { + labelType match { + case TweetSafetyLabelType.FosnrHatefulConduct => + AppealableReason.HatefulConduct(violationLevel.level) + case TweetSafetyLabelType.FosnrHatefulConductLowSeveritySlur => + AppealableReason.HatefulConduct(violationLevel.level) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + } + + def limitedActionConverter( + limitedActionStrings: Option[Seq[String]] + ): Option[LimitedActionsPolicy] = { + val limitedActions = limitedActionStrings.map { limitedActionString => + limitedActionString + .map(action => LimitedActionTypeConverter.fromString(action)).map { action => + action match { + case Some(a) => Some(LimitedAction(a, None)) + case _ => None + } + }.flatten + } + limitedActions.map(actions => LimitedActionsPolicy(actions)) + } +} + +object FreedomOfSpeechNotReachReason { + def unapply(softIntervention: SoftIntervention): Option[AppealableReason] = { + softIntervention.reason match { + case SoftInterventionReason.FosnrReason(appealableReason) => Some(appealableReason) + case _ => None + } + } + + def unapply( + interstitialLimitedEngagements: InterstitialLimitedEngagements + ): Option[AppealableReason] = { + interstitialLimitedEngagements.limitedEngagementReason match { + case Some(LimitedEngagementReason.FosnrReason(appealableReason)) => Some(appealableReason) + case _ => None + } + } + + def unapply( + interstitial: Interstitial + ): Option[AppealableReason] = { + interstitial.reason match { + case Reason.FosnrReason(appealableReason) => Some(appealableReason) + case _ => None + } + } + + def unapply( + appealable: Appealable + ): Option[AppealableReason] = { + Reason.toAppealableReason(appealable.reason, appealable.violationLevel) + } + + def unapply( + action: Action + ): Option[AppealableReason] = { + action match { + case a: SoftIntervention => + a match { + case FreedomOfSpeechNotReachReason(r) => Some(r) + case _ => None + } + case a: InterstitialLimitedEngagements => + a match { + case FreedomOfSpeechNotReachReason(r) => Some(r) + case _ => None + } + case a: Interstitial => + a match { + case FreedomOfSpeechNotReachReason(r) => Some(r) + case _ => None + } + case a: Appealable => + a match { + case FreedomOfSpeechNotReachReason(r) => Some(r) + case _ => None + } + case ComposableActionsWithSoftIntervention(FreedomOfSpeechNotReachReason(appealableReason)) => + Some(appealableReason) + case ComposableActionsWithInterstitialLimitedEngagements( + FreedomOfSpeechNotReachReason(appealableReason)) => + Some(appealableReason) + case ComposableActionsWithInterstitial(FreedomOfSpeechNotReachReason(appealableReason)) => + Some(appealableReason) + case ComposableActionsWithAppealable(FreedomOfSpeechNotReachReason(appealableReason)) => + Some(appealableReason) + case _ => None + } + } +} + +object FreedomOfSpeechNotReachActions { + + trait FreedomOfSpeechNotReachActionBuilder[T <: Action] extends 
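+// FreedomOfSpeechNotReachReason (above) is a single extractor over every action shape a
+// FOSNR rule can emit -- plain SoftIntervention / Interstitial / Appealable values and the
+// composed TweetInterstitial forms -- recovering the underlying AppealableReason. Hedged
+// usage sketch; `verdict` is a hypothetical Action and handleFosnr a hypothetical callback:
+//   verdict match {
+//     case FreedomOfSpeechNotReachReason(reason) => handleFosnr(reason)
+//     case _                                     => ()
+//   }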
ActionBuilder[T] { + def withViolationLevel(violationLevel: ViolationLevel): FreedomOfSpeechNotReachActionBuilder[T] + } + + case class DropAction(violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[Drop] { + + override def actionType: Class[_] = classOf[Drop] + + override val actionSeverity = 16 + private def toRuleResult: Reason => RuleResult = Memoize { r => RuleResult(Drop(r), Evaluated) } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class AppealableAction(violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[TweetInterstitial] { + + override def actionType: Class[_] = classOf[Appealable] + + override val actionSeverity = 17 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = None, + softIntervention = None, + limitedEngagements = None, + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None, + abusiveQuality = None, + appealable = Some(Appealable(r, violationLevel = violationLevel)) + ), + Evaluated + ) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class AppealableAvoidLimitedEngagementsAction( + violationLevel: ViolationLevel = DefaultViolationLevel, + limitedActionStrings: Option[Seq[String]]) + extends FreedomOfSpeechNotReachActionBuilder[Appealable] { + + override def actionType: Class[_] = classOf[AppealableAvoidLimitedEngagementsAction] + + override val actionSeverity = 17 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = None, + softIntervention = None, + limitedEngagements = Some( + LimitedEngagements( + toLimitedEngagementReason( + Reason + .toAppealableReason(r, violationLevel) + .getOrElse(AppealableReason.Unspecified(violationLevel.level))), + FreedomOfSpeechNotReach.limitedActionConverter(limitedActionStrings) + )), + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None, + abusiveQuality = None, + appealable = Some(Appealable(r, violationLevel = violationLevel)) + ), + Evaluated + ) + } + + def build( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + 
FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class AvoidAction(violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[Avoid] { + + override def actionType: Class[_] = classOf[Avoid] + + override val actionSeverity = 1 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult(Avoid(None), Evaluated) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class LimitedEngagementsAction(violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[LimitedEngagements] { + + override def actionType: Class[_] = classOf[LimitedEngagements] + + override val actionSeverity = 6 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult(LimitedEngagements(LimitedEngagementReason.NonCompliant, None), Evaluated) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class InterstitialLimitedEngagementsAction( + violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[InterstitialLimitedEngagements] { + + override def actionType: Class[_] = classOf[InterstitialLimitedEngagements] + + override val actionSeverity = 11 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult(InterstitialLimitedEngagements(r, None), Evaluated) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class InterstitialLimitedEngagementsAvoidAction( + violationLevel: ViolationLevel = DefaultViolationLevel, + limitedActionStrings: Option[Seq[String]]) + extends 
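+  // The builders in this object all share one template: memoize reason -> RuleResult,
+  // map the strongest eligible label to an AppealableReason in build(), and expose
+  // withViolationLevel as a case-class copy so policies can re-pin severity. Hedged
+  // sketch of that last step (DropAction is real; the level choice is illustrative):
+  //   val level3Drop = DropAction().withViolationLevel(ViolationLevel.Level3)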
FreedomOfSpeechNotReachActionBuilder[TweetInterstitial] { + + override def actionType: Class[_] = classOf[InterstitialLimitedEngagementsAvoidAction] + + override val actionSeverity = 14 + private def toRuleResult: AppealableReason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = Some( + Interstitial( + reason = FosnrReason(r), + localizedMessage = None, + )), + softIntervention = None, + limitedEngagements = Some( + LimitedEngagements( + reason = toLimitedEngagementReason(r), + policy = FreedomOfSpeechNotReach.limitedActionConverter(limitedActionStrings))), + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None + ), + Evaluated + ) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + labelType = label, + violationLevel = violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(appealableReason) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class ConversationSectionAbusiveQualityAction( + violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[ConversationSectionAbusiveQuality.type] { + + override def actionType: Class[_] = ConversationSectionAbusiveQuality.getClass + + override val actionSeverity = 5 + private def toRuleResult: Reason => RuleResult = Memoize { r => + RuleResult(ConversationSectionAbusiveQuality, Evaluated) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(Reason.fromAppealableReason(appealableReason)) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class SoftInterventionAvoidAction(violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[TweetInterstitial] { + + override def actionType: Class[_] = classOf[SoftInterventionAvoidAction] + + override val actionSeverity = 8 + private def toRuleResult: AppealableReason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = None, + softIntervention = Some( + SoftIntervention( + reason = toSoftInterventionReason(r), + engagementNudge = false, + suppressAutoplay = true, + warning = None, + detailsUrl = None, + displayType = Some(SoftInterventionDisplayType.Fosnr) + )), + limitedEngagements = None, + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None, + abusiveQuality = None + ), + Evaluated + ) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => 
+ AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(appealableReason) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class SoftInterventionAvoidLimitedEngagementsAction( + violationLevel: ViolationLevel = DefaultViolationLevel, + limitedActionStrings: Option[Seq[String]]) + extends FreedomOfSpeechNotReachActionBuilder[TweetInterstitial] { + + override def actionType: Class[_] = classOf[SoftInterventionAvoidLimitedEngagementsAction] + + override val actionSeverity = 13 + private def toRuleResult: AppealableReason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = None, + softIntervention = Some( + SoftIntervention( + reason = toSoftInterventionReason(r), + engagementNudge = false, + suppressAutoplay = true, + warning = None, + detailsUrl = None, + displayType = Some(SoftInterventionDisplayType.Fosnr) + )), + limitedEngagements = Some( + LimitedEngagements( + toLimitedEngagementReason(r), + FreedomOfSpeechNotReach.limitedActionConverter(limitedActionStrings))), + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None, + abusiveQuality = None + ), + Evaluated + ) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(appealableReason) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } + + case class SoftInterventionAvoidAbusiveQualityReplyAction( + violationLevel: ViolationLevel = DefaultViolationLevel) + extends FreedomOfSpeechNotReachActionBuilder[TweetInterstitial] { + + override def actionType: Class[_] = classOf[SoftInterventionAvoidAbusiveQualityReplyAction] + + override val actionSeverity = 13 + private def toRuleResult: AppealableReason => RuleResult = Memoize { r => + RuleResult( + TweetInterstitial( + interstitial = None, + softIntervention = Some( + SoftIntervention( + reason = toSoftInterventionReason(r), + engagementNudge = false, + suppressAutoplay = true, + warning = None, + detailsUrl = None, + displayType = Some(SoftInterventionDisplayType.Fosnr) + )), + limitedEngagements = None, + downrank = None, + avoid = Some(Avoid(None)), + mediaInterstitial = None, + tweetVisibilityNudge = None, + abusiveQuality = Some(ConversationSectionAbusiveQuality) + ), + Evaluated + ) + } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val appealableReason = + FreedomOfSpeechNotReach.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(label) => + FreedomOfSpeechNotReach.eligibleTweetSafetyLabelTypesToAppealableReason( + label, + violationLevel) + case _ => + AppealableReason.Unspecified(violationLevel.level) + } + + toRuleResult(appealableReason) + } + + override def withViolationLevel(violationLevel: ViolationLevel) = { + copy(violationLevel = violationLevel) + } + } +} + +object FreedomOfSpeechNotReachRules { + + abstract class OnlyWhenAuthorViewerRule( + actionBuilder: ActionBuilder[_ <: Action], + condition: Condition) + extends Rule(actionBuilder, And(Not(NonAuthorViewer), condition)) + + abstract 
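+  // OnlyWhenAuthorViewerRule and OnlyWhenNonAuthorViewerRule (next) split every FOSNR rule
+  // on authorship: And(Not(NonAuthorViewer), c) for the author's own view,
+  // And(NonAuthorViewer, c) for everyone else. Hedged wiring sketch -- both names are real
+  // (the case class is defined just below), the level/builder pairing is illustrative only:
+  //   ViewerIsAuthorAndTweetHasViolationOfLevel(
+  //     ViolationLevel.Level2,
+  //     FreedomOfSpeechNotReachActions.SoftInterventionAvoidAction())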
class OnlyWhenNonAuthorViewerRule( + actionBuilder: ActionBuilder[_ <: Action], + condition: Condition) + extends Rule(actionBuilder, And(NonAuthorViewer, condition)) + + case class ViewerIsAuthorAndTweetHasViolationOfLevel( + violationLevel: ViolationLevel, + override val actionBuilder: ActionBuilder[_ <: Action]) + extends OnlyWhenAuthorViewerRule( + actionBuilder, + Condition.TweetHasViolationOfLevel(violationLevel) + ) { + override lazy val name: String = s"ViewerIsAuthorAndTweetHasViolationOf$violationLevel" + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableFosnrRuleParam, FosnrRulesEnabledParam) + } + + case class ViewerIsFollowerAndTweetHasViolationOfLevel( + violationLevel: ViolationLevel, + override val actionBuilder: ActionBuilder[_ <: Action]) + extends OnlyWhenNonAuthorViewerRule( + actionBuilder, + And( + Condition.TweetHasViolationOfLevel(violationLevel), + ViewerFollowsAuthorOfFosnrViolatingTweet) + ) { + override lazy val name: String = s"ViewerIsFollowerAndTweetHasViolationOf$violationLevel" + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableFosnrRuleParam, FosnrRulesEnabledParam) + } + + case class ViewerIsNonFollowerNonAuthorAndTweetHasViolationOfLevel( + violationLevel: ViolationLevel, + override val actionBuilder: ActionBuilder[_ <: Action]) + extends OnlyWhenNonAuthorViewerRule( + actionBuilder, + And( + Condition.TweetHasViolationOfLevel(violationLevel), + ViewerDoesNotFollowAuthorOfFosnrViolatingTweet) + ) { + override lazy val name: String = + s"ViewerIsNonFollowerNonAuthorAndTweetHasViolationOf$violationLevel" + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableFosnrRuleParam, FosnrRulesEnabledParam) + } + + case class ViewerIsNonAuthorAndTweetHasViolationOfLevel( + violationLevel: ViolationLevel, + override val actionBuilder: ActionBuilder[_ <: Action]) + extends OnlyWhenNonAuthorViewerRule( + actionBuilder, + Condition.TweetHasViolationOfLevel(violationLevel) + ) { + override lazy val name: String = + s"ViewerIsNonAuthorAndTweetHasViolationOf$violationLevel" + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableFosnrRuleParam, FosnrRulesEnabledParam) + } + + case object TweetHasViolationOfAnyLevelFallbackDropRule + extends RuleWithConstantAction( + Drop(reason = NotSupportedOnDevice), + Condition.TweetHasViolationOfAnyLevel + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableFosnrRuleParam, FosnrFallbackDropRulesEnabledParam) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/InterstitialIf.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/InterstitialIf.scala new file mode 100644 index 000000000..62785b251 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/InterstitialIf.scala @@ -0,0 +1,43 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.Not + +object InterstitialIf { + + object ViewerMutedKeyword + extends RuleWithConstantAction( + Interstitial(Reason.MutedKeyword), + And( + Not(Condition.IsFocalTweet), + Condition.ViewerHasMatchingKeywordForTweetReplies, + ) + ) + + object ViewerBlockedAuthor + extends RuleWithConstantAction( + Interstitial(Reason.ViewerBlocksAuthor), + And( + Not(Condition.IsFocalTweet), + Condition.ViewerBlocksAuthor + ) + ) + + object ViewerHardMutedAuthor + extends RuleWithConstantAction( + Interstitial(Reason.ViewerHardMutedAuthor), + And( + Not(Condition.IsFocalTweet), + Condition.ViewerMutesAuthor, + Not( + 
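+        // "Hard muted" = muted by a viewer who does not follow the author (the Not(...)
+        // completing below); a mute from a viewer who still follows the author does not
+        // trigger this interstitial. Note the pattern in this file: the keyword, block,
+        // and mute rules are suppressed on the focal tweet via Not(IsFocalTweet), while
+        // ViewerReportedAuthor (next) applies to focal tweets as well.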
Condition.ViewerDoesFollowAuthor + ) + ) + ) + + object ViewerReportedAuthor + extends RuleWithConstantAction( + Interstitial(Reason.ViewerReportedAuthor), + Condition.ViewerReportsAuthor + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/PublicInterestRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/PublicInterestRules.scala new file mode 100644 index 000000000..047e349ae --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/PublicInterestRules.scala @@ -0,0 +1,327 @@ +package com.twitter.visibility.rules + +import com.twitter.guano.commons.thriftscala.PolicyInViolation +import com.twitter.spam.rtf.thriftscala.SafetyResultReason +import com.twitter.util.Memoize +import com.twitter.util.Time +import com.twitter.visibility.common.actions.ComplianceTweetNoticeEventType +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableSearchIpiSafeSearchWithoutUserInQueryDropRule +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.TweetSafetyLabels +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.LoggedOutOrViewerOptInFiltering +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.Or +import com.twitter.visibility.rules.Condition.SearchQueryHasUser +import com.twitter.visibility.rules.Condition.TweetComposedAfter +import com.twitter.visibility.rules.Condition.TweetHasLabel +import com.twitter.visibility.rules.Reason._ +import com.twitter.visibility.rules.State.Evaluated + +object PublicInterest { + object PolicyConfig { + val LowQualityProxyLabelStart: Time = Time.fromMilliseconds(1554076800000L) + val DefaultReason: (Reason, Option[LimitedEngagementReason]) = + (OneOff, Some(LimitedEngagementReason.NonCompliant)) + val DefaultPolicyInViolation: PolicyInViolation = PolicyInViolation.OneOff + } + + val policyInViolationToReason: Map[PolicyInViolation, Reason] = Map( + PolicyInViolation.AbusePolicyEpisodic -> AbuseEpisodic, + PolicyInViolation.AbusePolicyEpisodicEncourageSelfharm -> AbuseEpisodicEncourageSelfHarm, + PolicyInViolation.AbusePolicyEpisodicHatefulConduct -> AbuseEpisodicHatefulConduct, + PolicyInViolation.AbusePolicyGratuitousGore -> AbuseGratuitousGore, + PolicyInViolation.AbusePolicyGlorificationofViolence -> AbuseGlorificationOfViolence, + PolicyInViolation.AbusePolicyEncourageMobHarassment -> AbuseMobHarassment, + PolicyInViolation.AbusePolicyMomentofDeathDeceasedUser -> AbuseMomentOfDeathOrDeceasedUser, + PolicyInViolation.AbusePolicyPrivateInformation -> AbusePrivateInformation, + PolicyInViolation.AbusePolicyRighttoPrivacy -> AbuseRightToPrivacy, + PolicyInViolation.AbusePolicyThreattoExpose -> AbuseThreatToExpose, + PolicyInViolation.AbusePolicyViolentSexualConduct -> AbuseViolentSexualConduct, + PolicyInViolation.AbusePolicyViolentThreatsHatefulConduct -> AbuseViolentThreatHatefulConduct, + PolicyInViolation.AbusePolicyViolentThreatorBounty -> AbuseViolentThreatOrBounty, + PolicyInViolation.OneOff -> OneOff, + PolicyInViolation.AbusePolicyElectionInterference -> VotingMisinformation, + PolicyInViolation.MisinformationVoting -> VotingMisinformation, + PolicyInViolation.HackedMaterials -> HackedMaterials, + PolicyInViolation.Scam -> Scams, 
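+    // This forward map and reasonToPolicyInViolation (below) are maintained by hand and
+    // are not a bijection: AbusePolicyElectionInterference and MisinformationVoting both
+    // map to VotingMisinformation, so a round trip canonicalizes to MisinformationVoting.
+    // Hedged round-trip sketch using the helpers at the bottom of this object:
+    //   reasonToPolicy(policyToReason(PolicyInViolation.AbusePolicyElectionInterference))
+    //   // => PolicyInViolation.MisinformationVoting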
+ PolicyInViolation.PlatformManipulation -> PlatformManipulation, + PolicyInViolation.MisinformationCivic -> MisinfoCivic, + PolicyInViolation.AbusePolicyUkraineCrisisMisinformation -> MisinfoCrisis, + PolicyInViolation.MisinformationGeneric -> MisinfoGeneric, + PolicyInViolation.MisinformationMedical -> MisinfoMedical, + ) + + val reasonToPolicyInViolation: Map[Reason, PolicyInViolation] = Map( + AbuseEpisodic -> PolicyInViolation.AbusePolicyEpisodic, + AbuseEpisodicEncourageSelfHarm -> PolicyInViolation.AbusePolicyEpisodicEncourageSelfharm, + AbuseEpisodicHatefulConduct -> PolicyInViolation.AbusePolicyEpisodicHatefulConduct, + AbuseGratuitousGore -> PolicyInViolation.AbusePolicyGratuitousGore, + AbuseGlorificationOfViolence -> PolicyInViolation.AbusePolicyGlorificationofViolence, + AbuseMobHarassment -> PolicyInViolation.AbusePolicyEncourageMobHarassment, + AbuseMomentOfDeathOrDeceasedUser -> PolicyInViolation.AbusePolicyMomentofDeathDeceasedUser, + AbusePrivateInformation -> PolicyInViolation.AbusePolicyPrivateInformation, + AbuseRightToPrivacy -> PolicyInViolation.AbusePolicyRighttoPrivacy, + AbuseThreatToExpose -> PolicyInViolation.AbusePolicyThreattoExpose, + AbuseViolentSexualConduct -> PolicyInViolation.AbusePolicyViolentSexualConduct, + AbuseViolentThreatHatefulConduct -> PolicyInViolation.AbusePolicyViolentThreatsHatefulConduct, + AbuseViolentThreatOrBounty -> PolicyInViolation.AbusePolicyViolentThreatorBounty, + OneOff -> PolicyInViolation.OneOff, + VotingMisinformation -> PolicyInViolation.MisinformationVoting, + HackedMaterials -> PolicyInViolation.HackedMaterials, + Scams -> PolicyInViolation.Scam, + PlatformManipulation -> PolicyInViolation.PlatformManipulation, + MisinfoCivic -> PolicyInViolation.MisinformationCivic, + MisinfoCrisis -> PolicyInViolation.AbusePolicyUkraineCrisisMisinformation, + MisinfoGeneric -> PolicyInViolation.MisinformationGeneric, + MisinfoMedical -> PolicyInViolation.MisinformationMedical, + ) + + val ReasonToSafetyResultReason: Map[Reason, SafetyResultReason] = Map( + AbuseEpisodic -> SafetyResultReason.Episodic, + AbuseEpisodicEncourageSelfHarm -> SafetyResultReason.AbuseEpisodicEncourageSelfHarm, + AbuseEpisodicHatefulConduct -> SafetyResultReason.AbuseEpisodicHatefulConduct, + AbuseGratuitousGore -> SafetyResultReason.AbuseGratuitousGore, + AbuseGlorificationOfViolence -> SafetyResultReason.AbuseGlorificationOfViolence, + AbuseMobHarassment -> SafetyResultReason.AbuseMobHarassment, + AbuseMomentOfDeathOrDeceasedUser -> SafetyResultReason.AbuseMomentOfDeathOrDeceasedUser, + AbusePrivateInformation -> SafetyResultReason.AbusePrivateInformation, + AbuseRightToPrivacy -> SafetyResultReason.AbuseRightToPrivacy, + AbuseThreatToExpose -> SafetyResultReason.AbuseThreatToExpose, + AbuseViolentSexualConduct -> SafetyResultReason.AbuseViolentSexualConduct, + AbuseViolentThreatHatefulConduct -> SafetyResultReason.AbuseViolentThreatHatefulConduct, + AbuseViolentThreatOrBounty -> SafetyResultReason.AbuseViolentThreatOrBounty, + OneOff -> SafetyResultReason.OneOff, + VotingMisinformation -> SafetyResultReason.VotingMisinformation, + HackedMaterials -> SafetyResultReason.HackedMaterials, + Scams -> SafetyResultReason.Scams, + PlatformManipulation -> SafetyResultReason.PlatformManipulation, + MisinfoCivic -> SafetyResultReason.MisinfoCivic, + MisinfoCrisis -> SafetyResultReason.MisinfoCrisis, + MisinfoGeneric -> SafetyResultReason.MisinfoGeneric, + MisinfoMedical -> SafetyResultReason.MisinfoMedical, + IpiDevelopmentOnly -> 
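+    // (entry continues below) ReasonToSafetyResultReason is inverted verbatim into
+    // SafetyResultReasonToReason a few lines down, so its entries must stay one-to-one;
+    // a duplicated value would silently drop a Reason during the inversion.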
SafetyResultReason.DevelopmentOnlyPublicInterest + ) + + val Reasons: Set[Reason] = ReasonToSafetyResultReason.keySet + val SafetyResultReasons: Set[SafetyResultReason] = ReasonToSafetyResultReason.values.toSet + + val SafetyResultReasonToReason: Map[SafetyResultReason, Reason] = + ReasonToSafetyResultReason.map(t => t._2 -> t._1) + + val EligibleTweetSafetyLabelTypes: Seq[TweetSafetyLabelType] = Seq( + TweetSafetyLabelType.LowQuality, + TweetSafetyLabelType.MisinfoCivic, + TweetSafetyLabelType.MisinfoGeneric, + TweetSafetyLabelType.MisinfoMedical, + TweetSafetyLabelType.MisinfoCrisis, + TweetSafetyLabelType.IpiDevelopmentOnly + ) + + private val EligibleTweetSafetyLabelTypesSet = EligibleTweetSafetyLabelTypes.toSet + + def extractTweetSafetyLabel(featureMap: Map[Feature[_], _]): Option[TweetSafetyLabel] = { + val tweetSafetyLabels = featureMap(TweetSafetyLabels) + .asInstanceOf[Seq[TweetSafetyLabel]] + .flatMap { tsl => + if (PublicInterest.EligibleTweetSafetyLabelTypesSet.contains(tsl.labelType)) { + Some(tsl.labelType -> tsl) + } else { + None + } + } + .toMap + + PublicInterest.EligibleTweetSafetyLabelTypes.flatMap(tweetSafetyLabels.get).headOption + } + + def policyToReason(policy: PolicyInViolation): Reason = + policyInViolationToReason.get(policy).getOrElse(PolicyConfig.DefaultReason._1) + + def reasonToPolicy(reason: Reason): PolicyInViolation = + reasonToPolicyInViolation.get(reason).getOrElse(PolicyConfig.DefaultPolicyInViolation) +} + +class PublicInterestActionBuilder[T <: Action]() extends ActionBuilder[T] { + def actionType: Class[_] = classOf[InterstitialLimitedEngagements] + + override val actionSeverity = 11 + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val (reason, limitedEngagementReason) = + PublicInterest.extractTweetSafetyLabel(featureMap).map { tweetSafetyLabel => + (tweetSafetyLabel.labelType, tweetSafetyLabel.source) + } match { + case Some((TweetSafetyLabelType.LowQuality, source)) => + source match { + case Some(source) => + SafetyResultReason.valueOf(source.name) match { + case Some(matchedReason) + if PublicInterest.SafetyResultReasonToReason.contains(matchedReason) => + ( + PublicInterest.SafetyResultReasonToReason(matchedReason), + Some(LimitedEngagementReason.NonCompliant)) + case _ => PublicInterest.PolicyConfig.DefaultReason + } + case _ => PublicInterest.PolicyConfig.DefaultReason + } + + + case Some((TweetSafetyLabelType.MisinfoCivic, source)) => + (Reason.MisinfoCivic, LimitedEngagementReason.fromString(source.map(_.name))) + + case Some((TweetSafetyLabelType.MisinfoCrisis, source)) => + (Reason.MisinfoCrisis, LimitedEngagementReason.fromString(source.map(_.name))) + + case Some((TweetSafetyLabelType.MisinfoGeneric, source)) => + (Reason.MisinfoGeneric, LimitedEngagementReason.fromString(source.map(_.name))) + + case Some((TweetSafetyLabelType.MisinfoMedical, source)) => + (Reason.MisinfoMedical, LimitedEngagementReason.fromString(source.map(_.name))) + + case Some((TweetSafetyLabelType.IpiDevelopmentOnly, _)) => + (Reason.IpiDevelopmentOnly, Some(LimitedEngagementReason.NonCompliant)) + + case _ => + PublicInterest.PolicyConfig.DefaultReason + } + + RuleResult(InterstitialLimitedEngagements(reason, limitedEngagementReason), Evaluated) + } +} + +class PublicInterestComplianceTweetNoticeActionBuilder + extends ActionBuilder[ComplianceTweetNoticePreEnrichment] { + + override def actionType: Class[_] = classOf[ComplianceTweetNoticePreEnrichment] + + override val actionSeverity = 2 + def 
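+  // As in PublicInterestActionBuilder above: a LowQuality label's policy reason is
+  // recovered by parsing the label's *source name* as a SafetyResultReason and mapping it
+  // through SafetyResultReasonToReason, defaulting to PolicyConfig.DefaultReason._1
+  // (OneOff) when the source is missing or unrecognized. Hedged sketch; the source name
+  // "Scams" is illustrative only:
+  //   SafetyResultReason.valueOf("Scams").flatMap(PublicInterest.SafetyResultReasonToReason.get)
+  //   // => Some(Scams): Option[Reason], under that assumption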
build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val reason = + PublicInterest.extractTweetSafetyLabel(featureMap).map { tweetSafetyLabel => + (tweetSafetyLabel.labelType, tweetSafetyLabel.source) + } match { + case Some((TweetSafetyLabelType.LowQuality, source)) => + source match { + case Some(source) => + SafetyResultReason.valueOf(source.name) match { + case Some(matchedReason) + if PublicInterest.SafetyResultReasonToReason.contains(matchedReason) => + PublicInterest.SafetyResultReasonToReason(matchedReason) + case _ => PublicInterest.PolicyConfig.DefaultReason._1 + } + case _ => PublicInterest.PolicyConfig.DefaultReason._1 + } + + + case Some((TweetSafetyLabelType.MisinfoCivic, _)) => + Reason.MisinfoCivic + + case Some((TweetSafetyLabelType.MisinfoCrisis, _)) => + Reason.MisinfoCrisis + + case Some((TweetSafetyLabelType.MisinfoGeneric, _)) => + Reason.MisinfoGeneric + + case Some((TweetSafetyLabelType.MisinfoMedical, _)) => + Reason.MisinfoMedical + + case Some((TweetSafetyLabelType.IpiDevelopmentOnly, _)) => + Reason.IpiDevelopmentOnly + + case _ => + PublicInterest.PolicyConfig.DefaultReason._1 + } + + RuleResult( + ComplianceTweetNoticePreEnrichment(reason, ComplianceTweetNoticeEventType.PublicInterest), + Evaluated) + } +} + +class PublicInterestDropActionBuilder extends ActionBuilder[Drop] { + + override def actionType: Class[_] = classOf[Drop] + + override val actionSeverity = 16 + private def toRuleResult: Reason => RuleResult = Memoize { r => RuleResult(Drop(r), Evaluated) } + + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = { + val reason = PublicInterest.extractTweetSafetyLabel(featureMap).map(_.labelType) match { + case Some(TweetSafetyLabelType.LowQuality) => + Reason.OneOff + + case Some(TweetSafetyLabelType.MisinfoCivic) => + Reason.MisinfoCivic + + case Some(TweetSafetyLabelType.MisinfoCrisis) => + Reason.MisinfoCrisis + + case Some(TweetSafetyLabelType.MisinfoGeneric) => + Reason.MisinfoGeneric + + case Some(TweetSafetyLabelType.MisinfoMedical) => + Reason.MisinfoMedical + + case _ => + Reason.OneOff + } + + toRuleResult(reason) + } +} + +object PublicInterestRules { + + object AbusePolicyEpisodicTweetLabelInterstitialRule + extends Rule( + new PublicInterestActionBuilder(), + And( + TweetComposedAfter(PublicInterest.PolicyConfig.LowQualityProxyLabelStart), + Or( + PublicInterest.EligibleTweetSafetyLabelTypes.map(TweetHasLabel(_)): _* + ) + ) + ) + + object AbusePolicyEpisodicTweetLabelComplianceTweetNoticeRule + extends Rule( + new PublicInterestComplianceTweetNoticeActionBuilder(), + And( + TweetComposedAfter(PublicInterest.PolicyConfig.LowQualityProxyLabelStart), + Or( + PublicInterest.EligibleTweetSafetyLabelTypes.map(TweetHasLabel(_)): _* + ) + ) + ) + + object AbusePolicyEpisodicTweetLabelDropRule + extends Rule( + new PublicInterestDropActionBuilder(), + And( + TweetComposedAfter(PublicInterest.PolicyConfig.LowQualityProxyLabelStart), + Or( + PublicInterest.EligibleTweetSafetyLabelTypes.map(TweetHasLabel(_)): _* + ) + ) + ) + + object SearchIpiSafeSearchWithoutUserInQueryDropRule + extends Rule( + new PublicInterestDropActionBuilder(), + And( + TweetComposedAfter(PublicInterest.PolicyConfig.LowQualityProxyLabelStart), + Or( + PublicInterest.EligibleTweetSafetyLabelTypes.map(TweetHasLabel(_)): _* + ), + LoggedOutOrViewerOptInFiltering, + Not(SearchQueryHasUser) + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + 
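+      // Same decider param as SearchEdiSafeSearchWithoutUserInQueryDropRule in
+      // ForEmergencyUseOnly.scala, so one kill switch appears to govern both the IPI and
+      // EDI safe-search drop rules.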
EnableSearchIpiSafeSearchWithoutUserInQueryDropRule) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rule.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rule.scala new file mode 100644 index 000000000..ab1c21e13 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rule.scala @@ -0,0 +1,215 @@ +package com.twitter.visibility.rules + +import com.twitter.abdecider.LoggingABDecider +import com.twitter.timelines.configapi.HasParams.DependencyProvider +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.configapi.params.RuleParams.EnableLikelyIvsUserLabelDropRule +import com.twitter.visibility.features._ +import com.twitter.visibility.models.UserLabelValue +import com.twitter.visibility.models.UserLabelValue.LikelyIvs +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Reason.Unspecified +import com.twitter.visibility.rules.RuleActionSourceBuilder.UserSafetyLabelSourceBuilder +import com.twitter.visibility.rules.State._ +import com.twitter.visibility.util.NamingUtils + +trait WithGate { + def enabled: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) + + def isEnabled(params: Params): Boolean = + enabled.forall(enabledParam => params(enabledParam)) + + def holdbacks: Seq[RuleParam[Boolean]] = Seq(RuleParams.False) + + final def shouldHoldback: DependencyProvider[Boolean] = + holdbacks.foldLeft(DependencyProvider.from(RuleParams.False)) { (dp, holdbackParam) => + dp.or(DependencyProvider.from(holdbackParam)) + } + + protected def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.False) + def shouldFailClosed(params: Params): Boolean = + enableFailClosed.forall(fcParam => params(fcParam)) +} + +abstract class ActionBuilder[T <: Action] { + def actionType: Class[_] + + val actionSeverity: Int + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult +} + +object ActionBuilder { + def apply[T <: Action](action: T): ActionBuilder[T] = action match { + case _: InterstitialLimitedEngagements => new PublicInterestActionBuilder() + case _ => new ConstantActionBuilder(action) + } +} + +class ConstantActionBuilder[T <: Action](action: T) extends ActionBuilder[T] { + private val result = RuleResult(action, Evaluated) + + def actionType: Class[_] = action.getClass + + override val actionSeverity = action.severity + def build(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): RuleResult = + result +} + +object ConstantActionBuilder { + def unapply[T <: Action](builder: ConstantActionBuilder[T]): Option[Action] = Some( + builder.result.action) +} + +abstract class Rule(val actionBuilder: ActionBuilder[_ <: Action], val condition: Condition) + extends WithGate { + + import Rule._ + def isExperimental: Boolean = false + + def actionSourceBuilder: Option[RuleActionSourceBuilder] = None + + lazy val name: String = NamingUtils.getFriendlyName(this) + + val featureDependencies: Set[Feature[_]] = condition.features + + val optionalFeatureDependencies: Set[Feature[_]] = condition.optionalFeatures + + def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], Any], + abDecider: LoggingABDecider + ): PreFilterResult = + condition.preFilter(evaluationContext, featureMap) + + def actWhen(evaluationContext: EvaluationContext, featureMap: Map[Feature[_], _]): Boolean = + 
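+  // actWhen is the trigger predicate; evaluate (below) is the full lifecycle: fail closed
+  // on missing features (unless fallbackActionBuilder is set), run the condition, honor
+  // holdbacks, then ask the ActionBuilder for a verdict, converting any Throwable into
+  // RuleResult(NotEvaluated, RuleFailed(t)). Hedged call sketch; `ctx` and `featureMap`
+  // are hypothetical values of the right types:
+  //   val result = EmergencyDropRule.evaluate(ctx, featureMap)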
condition(evaluationContext, featureMap).asBoolean + + val fallbackActionBuilder: Option[ActionBuilder[_ <: Action]] = None + + final def evaluate( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], _] + ): RuleResult = { + val missingFeatures = featureDependencies.filterNot(featureMap.contains) + + if (missingFeatures.nonEmpty) { + fallbackActionBuilder match { + case Some(fallbackAction) => + fallbackAction.build(evaluationContext, featureMap) + case None => + RuleResult(NotEvaluated, MissingFeature(missingFeatures)) + } + } else { + try { + val act = actWhen(evaluationContext, featureMap) + if (!act) { + EvaluatedRuleResult + } else if (shouldHoldback(evaluationContext)) { + + HeldbackRuleResult + } else { + actionBuilder.build(evaluationContext, featureMap) + } + } catch { + case t: Throwable => + RuleResult(NotEvaluated, RuleFailed(t)) + } + } + } +} + +trait ExperimentalRule extends Rule { + override def isExperimental: Boolean = true +} + +object Rule { + + val HeldbackRuleResult: RuleResult = RuleResult(Allow, Heldback) + val EvaluatedRuleResult: RuleResult = RuleResult(Allow, Evaluated) + val DisabledRuleResult: RuleResult = RuleResult(NotEvaluated, Disabled) + + def unapply(rule: Rule): Option[(ActionBuilder[_ <: Action], Condition)] = + Some((rule.actionBuilder, rule.condition)) +} + +abstract class RuleWithConstantAction(val action: Action, override val condition: Condition) + extends Rule(ActionBuilder(action), condition) + +abstract class UserHasLabelRule(action: Action, userLabelValue: UserLabelValue) + extends RuleWithConstantAction(action, AuthorHasLabel(userLabelValue)) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + UserSafetyLabelSourceBuilder(userLabelValue)) +} + +abstract class ConditionWithUserLabelRule( + action: Action, + condition: Condition, + userLabelValue: UserLabelValue) + extends Rule( + ActionBuilder(action), + And(NonAuthorViewer, AuthorHasLabel(userLabelValue), condition)) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + UserSafetyLabelSourceBuilder(userLabelValue)) +} + +abstract class WhenAuthorUserLabelPresentRule(action: Action, userLabelValue: UserLabelValue) + extends ConditionWithUserLabelRule(action, Condition.True, userLabelValue) + +abstract class ConditionWithNotInnerCircleOfFriendsRule( + action: Action, + condition: Condition) + extends RuleWithConstantAction( + action, + And(Not(DoesHaveInnerCircleOfFriendsRelationship), condition)) + +abstract class AuthorLabelWithNotInnerCircleOfFriendsRule( + action: Action, + userLabelValue: UserLabelValue) + extends ConditionWithNotInnerCircleOfFriendsRule( + action, + AuthorHasLabel(userLabelValue) + ) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + UserSafetyLabelSourceBuilder(userLabelValue)) +} + +abstract class OnlyWhenNotAuthorViewerRule(action: Action, condition: Condition) + extends RuleWithConstantAction(action, And(NonAuthorViewer, condition)) + +abstract class AuthorLabelAndNonFollowerViewerRule(action: Action, userLabelValue: UserLabelValue) + extends ConditionWithUserLabelRule(action, LoggedOutOrViewerNotFollowingAuthor, userLabelValue) + +abstract class AlwaysActRule(action: Action) extends Rule(ActionBuilder(action), Condition.True) + +abstract class ViewerOptInBlockingOnSearchRule(action: Action, condition: Condition) + extends OnlyWhenNotAuthorViewerRule( + action, + And(condition, ViewerOptInBlockingOnSearch) + ) + +abstract class ViewerOptInFilteringOnSearchRule(action: 
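+// The abstract classes in this stretch are rule templates: each pre-composes a common
+// guard (non-author viewer, inner-circle exemption, author label, follow state, search
+// opt-in) so a concrete rule supplies only an action and a label or condition. Hedged
+// sketch of a concrete rule under the pattern; ExampleIvsDropRule is hypothetical, the
+// types are real:
+//   object ExampleIvsDropRule extends UserHasLabelRule(Drop(Unspecified), LikelyIvs)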
Action, condition: Condition) + extends OnlyWhenNotAuthorViewerRule( + action, + And(condition, ViewerOptInFilteringOnSearch) + ) + +abstract class ViewerOptInFilteringOnSearchUserLabelRule( + action: Action, + userLabelValue: UserLabelValue, + prerequisiteCondition: Condition = True) + extends ConditionWithUserLabelRule( + action, + And(prerequisiteCondition, LoggedOutOrViewerOptInFiltering), + userLabelValue + ) + +abstract class LikelyIvsLabelNonFollowerDropRule + extends AuthorLabelAndNonFollowerViewerRule( + Drop(Unspecified), + LikelyIvs + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableLikelyIvsUserLabelDropRule) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleActionSourceBuilder.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleActionSourceBuilder.scala new file mode 100644 index 000000000..72d54c677 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleActionSourceBuilder.scala @@ -0,0 +1,97 @@ +package com.twitter.visibility.rules + +import com.twitter.escherbird.thriftscala.TweetEntityAnnotation +import com.twitter.gizmoduck.thriftscala.Label +import com.twitter.spam.rtf.thriftscala.BotMakerAction +import com.twitter.spam.rtf.thriftscala.SafetyLabelSource +import com.twitter.spam.rtf.thriftscala.SemanticCoreAction +import com.twitter.visibility.common.actions.EscherbirdAnnotation +import com.twitter.visibility.common.actions.SoftInterventionReason +import com.twitter.visibility.configapi.configs.DeciderKey +import com.twitter.visibility.features.AuthorUserLabels +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.TweetSafetyLabels +import com.twitter.visibility.logging.thriftscala.ActionSource +import com.twitter.visibility.models.LabelSource._ +import com.twitter.visibility.models.TweetSafetyLabel +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.models.UserLabel +import com.twitter.visibility.models.UserLabelValue + +sealed trait RuleActionSourceBuilder { + def build(resolvedFeatureMap: Map[Feature[_], Any], verdict: Action): Option[ActionSource] + +} + +object RuleActionSourceBuilder { + + case class TweetSafetyLabelSourceBuilder(tweetSafetyLabelType: TweetSafetyLabelType) + extends RuleActionSourceBuilder { + override def build( + resolvedFeatureMap: Map[Feature[_], Any], + verdict: Action + ): Option[ActionSource] = { + resolvedFeatureMap + .getOrElse(TweetSafetyLabels, Seq.empty[TweetSafetyLabel]) + .asInstanceOf[Seq[TweetSafetyLabel]] + .find(_.labelType == tweetSafetyLabelType) + .flatMap(_.safetyLabelSource) + .map(ActionSource.SafetyLabelSource(_)) + } + } + + case class UserSafetyLabelSourceBuilder(userLabel: UserLabelValue) + extends RuleActionSourceBuilder { + override def build( + resolvedFeatureMap: Map[Feature[_], Any], + verdict: Action + ): Option[ActionSource] = { + resolvedFeatureMap + .getOrElse(AuthorUserLabels, Seq.empty[Label]) + .asInstanceOf[Seq[Label]] + .map(UserLabel.fromThrift) + .find(_.labelValue == userLabel) + .flatMap(_.source) + .collect { + case BotMakerRule(ruleId) => + ActionSource.SafetyLabelSource(SafetyLabelSource.BotMakerAction(BotMakerAction(ruleId))) + } + } + } + + case class SemanticCoreActionSourceBuilder() extends RuleActionSourceBuilder { + override def build( + resolvedFeatureMap: Map[Feature[_], Any], + verdict: Action + ): Option[ActionSource] = { + verdict match { + case softIntervention: SoftIntervention => + 
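+          // The three source builders attribute a verdict to its origin: a tweet label's
+          // SafetyLabelSource, a user label's BotMaker rule, or (here) the Escherbird
+          // annotations on a SoftIntervention. Caution: getSemanticCoreActionSourceOption
+          // below casts with asInstanceOf[SoftInterventionReason.EscherbirdAnnotations]
+          // and will throw for any other reason shape, so wire this builder only to rules
+          // that produce Escherbird-based soft interventions.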
getSemanticCoreActionSourceOption(softIntervention) + case tweetInterstitial: TweetInterstitial => + tweetInterstitial.softIntervention.flatMap(getSemanticCoreActionSourceOption) + case _ => None + } + } + + def getSemanticCoreActionSourceOption( + softIntervention: SoftIntervention + ): Option[ActionSource] = { + val siReason = softIntervention.reason + .asInstanceOf[SoftInterventionReason.EscherbirdAnnotations] + val firstAnnotation: Option[EscherbirdAnnotation] = + siReason.escherbirdAnnotations.headOption + + firstAnnotation.map { annotation => + ActionSource.SafetyLabelSource( + SafetyLabelSource.SemanticCoreAction(SemanticCoreAction( + TweetEntityAnnotation(annotation.groupId, annotation.domainId, annotation.entityId)))) + } + } + } +} + +trait DoesLogVerdict {} + +trait DoesLogVerdictDecidered extends DoesLogVerdict { + def verdictLogDeciderKey: DeciderKey.Value +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleBase.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleBase.scala new file mode 100644 index 000000000..66cbae0d1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/RuleBase.scala @@ -0,0 +1,238 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.features.AuthorScreenName +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.features.RawQuery +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevel._ + +object RuleBase { + + val DeprecatedFeatures: Seq[Feature[_]] = + Seq(AuthorScreenName, RawQuery) + + val RuleMap: Map[SafetyLevel, VisibilityPolicy] = Map( + AccessInternalPromotedContent -> InternalPromotedContentPolicy, + AllSubscribedLists -> AllSubscribedListsPolicy, + AdsBusinessSettings -> AdsBusinessSettingsPolicy, + AdsCampaign -> AdsCampaignPolicy, + AdsManager -> AdsManagerPolicy, + AdsReportingDashboard -> AdsReportingDashboardPolicy, + Appeals -> AppealsPolicy, + ArticleTweetTimeline -> ArticleTweetTimelinePolicy, + BaseQig -> BaseQigPolicy, + BirdwatchNoteAuthor -> BirdwatchNoteAuthorPolicy, + BirdwatchNoteTweetsTimeline -> BirdwatchNoteTweetsTimelinePolicy, + BirdwatchNeedsYourHelpNotifications -> BirdwatchNeedsYourHelpNotificationsPolicy, + BlockMuteUsersTimeline -> BlockMuteUsersTimelinePolicy, + BrandSafety -> BrandSafetyPolicy, + CardPollVoting -> CardPollVotingPolicy, + CardsService -> CardsServicePolicy, + Communities -> CommunitiesPolicy, + ContentControlToolInstall -> ContentControlToolInstallPolicy, + ConversationFocalPrehydration -> ConversationFocalPrehydrationPolicy, + ConversationFocalTweet -> ConversationFocalTweetPolicy, + ConversationInjectedTweet -> ConversationInjectedTweetPolicy, + ConversationReply -> ConversationReplyPolicy, + CuratedTrendsRepresentativeTweet -> CuratedTrendsRepresentativeTweetPolicy, + CurationPolicyViolations -> CurationPolicyViolationsPolicy, + FollowingAndFollowersUserList -> FollowingAndFollowersUserListPolicy, + DeprecatedSafetyLevel -> FilterNonePolicy, + DevPlatformGetListTweets -> DevPlatformGetListTweetsPolicy, + DesFollowingAndFollowersUserList -> FollowingAndFollowersUserListPolicy, + DesHomeTimeline -> DESHomeTimelinePolicy, + DesQuoteTweetTimeline -> DesQuoteTweetTimelinePolicy, + DesRealtime -> DESRealtimePolicy, + DesRealtimeSpamEnrichment -> DESRealtimeSpamEnrichmentPolicy, + DesRealtimeTweetFilter -> DESRealtimeSpamEnrichmentPolicy, + DesRetweetingUsers -> DESRetweetingUsersPolicy, + 
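+    // RuleMap is the single dispatch table from SafetyLevel to VisibilityPolicy;
+    // EvaluationContext.apply and EvaluationContext.Builder.build (earlier in this patch)
+    // resolve policies through it, so every SafetyLevel needs an entry here.
+    // Hedged lookup sketch (TimelineHome is a real key further down this map):
+    //   val policy: VisibilityPolicy = RuleBase.RuleMap(SafetyLevel.TimelineHome)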
DesTweetDetail -> DesTweetDetailPolicy, + DesTweetLikingUsers -> DESTweetLikingUsersPolicy, + DesUserBookmarks -> DESUserBookmarksPolicy, + DesUserLikedTweets -> DESUserLikedTweetsPolicy, + DesUserMentions -> DESUserMentionsPolicy, + DesUserTweets -> DESUserTweetsPolicy, + DevPlatformComplianceStream -> DevPlatformComplianceStreamPolicy, + DirectMessages -> DirectMessagesPolicy, + DirectMessagesConversationList -> DirectMessagesConversationListPolicy, + DirectMessagesConversationTimeline -> DirectMessagesConversationTimelinePolicy, + DirectMessagesInbox -> DirectMessagesInboxPolicy, + DirectMessagesMutedUsers -> DirectMessagesMutedUsersPolicy, + DirectMessagesPinned -> DirectMessagesPinnedPolicy, + DirectMessagesSearch -> DirectMessagesSearchPolicy, + EditHistoryTimeline -> EditHistoryTimelinePolicy, + ElevatedQuoteTweetTimeline -> ElevatedQuoteTweetTimelinePolicy, + EmbeddedTweet -> EmbeddedTweetsPolicy, + EmbedsPublicInterestNotice -> EmbedsPublicInterestNoticePolicy, + EmbedTweetMarkup -> EmbedTweetMarkupPolicy, + WritePathLimitedActionsEnforcement -> WritePathLimitedActionsEnforcementPolicy, + FilterAll -> FilterAllPolicy, + FilterAllPlaceholder -> FilterAllPolicy, + FilterNone -> FilterNonePolicy, + FilterDefault -> FilterDefaultPolicy, + FollowedTopicsTimeline -> FollowedTopicsTimelinePolicy, + FollowerConnections -> FollowerConnectionsPolicy, + ForDevelopmentOnly -> ForDevelopmentOnlyPolicy, + FriendsFollowingList -> FriendsFollowingListPolicy, + GraphqlDefault -> GraphqlDefaultPolicy, + GryphonDecksAndColumns -> GryphonDecksAndColumnsSharingPolicy, + HumanizationNudge -> HumanizationNudgePolicy, + KitchenSinkDevelopment -> KitchenSinkDevelopmentPolicy, + ListHeader -> ListHeaderPolicy, + ListMemberships -> ListMembershipsPolicy, + ListOwnerships -> ListOwnershipsPolicy, + ListRecommendations -> ListRecommendationsPolicy, + ListSearch -> ListSearchPolicy, + ListSubscriptions -> ListSubscriptionsPolicy, + LivePipelineEngagementCounts -> LivePipelineEngagementCountsPolicy, + LiveVideoTimeline -> LiveVideoTimelinePolicy, + MagicRecs -> MagicRecsPolicy, + MagicRecsAggressive -> MagicRecsAggressivePolicy, + MagicRecsAggressiveV2 -> MagicRecsAggressiveV2Policy, + MagicRecsV2 -> MagicRecsV2Policy, + Minimal -> MinimalPolicy, + ModeratedTweetsTimeline -> ModeratedTweetsTimelinePolicy, + Moments -> MomentsPolicy, + NearbyTimeline -> NearbyTimelinePolicy, + NewUserExperience -> NewUserExperiencePolicy, + NotificationsIbis -> NotificationsIbisPolicy, + NotificationsPlatform -> NotificationsPlatformPolicy, + NotificationsPlatformPush -> NotificationsPlatformPushPolicy, + NotificationsQig -> NotificationsQigPolicy, + NotificationsRead -> NotificationsReadPolicy, + NotificationsTimelineDeviceFollow -> NotificationsTimelineDeviceFollowPolicy, + NotificationsWrite -> NotificationsWritePolicy, + NotificationsWriterV2 -> NotificationsWriterV2Policy, + NotificationsWriterTweetHydrator -> NotificationsWriterTweetHydratorPolicy, + ProfileMixerMedia -> ProfileMixerMediaPolicy, + ProfileMixerFavorites -> ProfileMixerFavoritesPolicy, + QuickPromoteTweetEligibility -> QuickPromoteTweetEligibilityPolicy, + QuoteTweetTimeline -> QuoteTweetTimelinePolicy, + QuotedTweetRules -> QuotedTweetRulesPolicy, + Recommendations -> RecommendationsPolicy, + RecosVideo -> RecosVideoPolicy, + RecosWritePath -> RecosWritePathPolicy, + RepliesGrouping -> RepliesGroupingPolicy, + ReportCenter -> ReportCenterPolicy, + ReturningUserExperience -> ReturningUserExperiencePolicy, + ReturningUserExperienceFocalTweet -> 
ReturningUserExperienceFocalTweetPolicy, + Revenue -> RevenuePolicy, + RitoActionedTweetTimeline -> RitoActionedTweetTimelinePolicy, + SearchHydration -> SearchHydrationPolicy, + SearchMixerSrpMinimal -> SearchMixerSrpMinimalPolicy, + SearchMixerSrpStrict -> SearchMixerSrpStrictPolicy, + SearchLatest -> SearchLatestPolicy, + SearchPeopleSrp -> SearchPeopleSrpPolicy, + SearchPeopleTypeahead -> SearchPeopleTypeaheadPolicy, + SearchPhoto -> SearchPhotoPolicy, + SearchTrendTakeoverPromotedTweet -> SearchTrendTakeoverPromotedTweetPolicy, + SearchTop -> SearchTopPolicy, + SearchTopQig -> SearchTopQigPolicy, + SearchVideo -> SearchVideoPolicy, + SearchBlenderUserRules -> SearchBlenderUserRulesPolicy, + SearchLatestUserRules -> SearchLatestUserRulesPolicy, + ShoppingManagerSpyMode -> ShoppingManagerSpyModePolicy, + SignalsReactions -> SignalsReactionsPolicy, + SignalsTweetReactingUsers -> SignalsTweetReactingUsersPolicy, + SocialProof -> SocialProofPolicy, + SoftInterventionPivot -> SoftInterventionPivotPolicy, + SpaceFleetline -> SpaceFleetlinePolicy, + SpaceHomeTimelineUpranking -> SpaceHomeTimelineUprankingPolicy, + SpaceJoinScreen -> SpaceJoinScreenPolicy, + SpaceNotifications -> SpaceNotificationsPolicy, + Spaces -> SpacesPolicy, + SpacesParticipants -> SpacesParticipantsPolicy, + SpacesSellerApplicationStatus -> SpacesSellerApplicationStatusPolicy, + SpacesSharing -> SpacesSharingPolicy, + SpaceTweetAvatarHomeTimeline -> SpaceTweetAvatarHomeTimelinePolicy, + StickersTimeline -> StickersTimelinePolicy, + StratoExtLimitedEngagements -> StratoExtLimitedEngagementsPolicy, + StreamServices -> StreamServicesPolicy, + SuperFollowerConnections -> SuperFollowerConnectionsPolicy, + SuperLike -> SuperLikePolicy, + Test -> TestPolicy, + TimelineContentControls -> TimelineContentControlsPolicy, + TimelineConversations -> TimelineConversationsPolicy, + TimelineConversationsDownranking -> TimelineConversationsDownrankingPolicy, + TimelineConversationsDownrankingMinimal -> TimelineConversationsDownrankingMinimalPolicy, + TimelineFollowingActivity -> TimelineFollowingActivityPolicy, + TimelineHome -> TimelineHomePolicy, + TimelineHomeCommunities -> TimelineHomeCommunitiesPolicy, + TimelineHomeHydration -> TimelineHomeHydrationPolicy, + TimelineHomePromotedHydration -> TimelineHomePromotedHydrationPolicy, + TimelineHomeRecommendations -> TimelineHomeRecommendationsPolicy, + TimelineHomeTopicFollowRecommendations -> TimelineHomeTopicFollowRecommendationsPolicy, + TimelineScorer -> TimelineScorerPolicy, + TopicsLandingPageTopicRecommendations -> TopicsLandingPageTopicRecommendationsPolicy, + ExploreRecommendations -> ExploreRecommendationsPolicy, + TimelineInjection -> TimelineInjectionPolicy, + TimelineMentions -> TimelineMentionsPolicy, + TimelineModeratedTweetsHydration -> TimelineModeratedTweetsHydrationPolicy, + TimelineHomeLatest -> TimelineHomeLatestPolicy, + TimelineLikedBy -> TimelineLikedByPolicy, + TimelineRetweetedBy -> TimelineRetweetedByPolicy, + TimelineSuperLikedBy -> TimelineSuperLikedByPolicy, + TimelineBookmark -> TimelineBookmarkPolicy, + TimelineMedia -> TimelineMediaPolicy, + TimelineReactiveBlending -> TimelineReactiveBlendingPolicy, + TimelineFavorites -> TimelineFavoritesPolicy, + TimelineFavoritesSelfView -> TimelineFavoritesSelfViewPolicy, + TimelineLists -> TimelineListsPolicy, + TimelineProfile -> TimelineProfilePolicy, + TimelineProfileAll -> TimelineProfileAllPolicy, + TimelineProfileSpaces -> TimelineProfileSpacesPolicy, + TimelineProfileSuperFollows -> 
TimelineProfileSuperFollowsPolicy, + TimelineFocalTweet -> TimelineFocalTweetPolicy, + Tombstoning -> TombstoningPolicy, + TopicRecommendations -> TopicRecommendationsPolicy, + TrendsRepresentativeTweet -> TrendsRepresentativeTweetPolicy, + TrustedFriendsUserList -> TrustedFriendsUserListPolicy, + TweetDetail -> TweetDetailPolicy, + TweetDetailNonToo -> TweetDetailNonTooPolicy, + TweetDetailWithInjectionsHydration -> TweetDetailWithInjectionsHydrationPolicy, + TweetEngagers -> TweetEngagersPolicy, + TweetReplyNudge -> TweetReplyNudgePolicy, + TweetScopedTimeline -> TweetScopedTimelinePolicy, + TweetWritesApi -> TweetWritesApiPolicy, + TwitterArticleCompose -> TwitterArticleComposePolicy, + TwitterArticleProfileTab -> TwitterArticleProfileTabPolicy, + TwitterArticleRead -> TwitterArticleReadPolicy, + UserMilestoneRecommendation -> UserMilestoneRecommendationPolicy, + UserProfileHeader -> UserProfileHeaderPolicy, + UserScopedTimeline -> UserScopedTimelinePolicy, + UserSearchSrp -> UserSearchSrpPolicy, + UserSearchTypeahead -> UserSearchTypeaheadPolicy, + UserSelfViewOnly -> UserSelfViewOnlyPolicy, + UserSettings -> UserSettingsPolicy, + VideoAds -> VideoAdsPolicy, + ZipbirdConsumerArchives -> ZipbirdConsumerArchivesPolicy, + TweetAward -> TweetAwardPolicy, + ) + + def removeUnusedFeaturesFromFeatureMap( + featureMap: FeatureMap, + rules: Seq[Rule], + ): FeatureMap = { + val featuresInSafetyLevel: Set[Feature[_]] = + RuleBase.getFeaturesForRules(rules) + val filteredMap = featureMap.map.filterKeys(featuresInSafetyLevel.contains(_)) + + new FeatureMap(filteredMap, featureMap.constantMap) + } + + def getFeaturesForRules(rules: Seq[Rule]): Set[Feature[_]] = { + rules.flatMap { r: Rule => + r.featureDependencies ++ r.optionalFeatureDependencies + }.toSet + } + + def hasTweetRules(safetyLevel: SafetyLevel): Boolean = RuleMap(safetyLevel).tweetRules.nonEmpty + def hasUserRules(safetyLevel: SafetyLevel): Boolean = RuleMap(safetyLevel).userRules.nonEmpty + def hasCardRules(safetyLevel: SafetyLevel): Boolean = RuleMap(safetyLevel).cardRules.nonEmpty + def hasDmRules(safetyLevel: SafetyLevel): Boolean = RuleMap(safetyLevel).dmRules.nonEmpty + def hasDmConversationRules(safetyLevel: SafetyLevel): Boolean = RuleMap( + safetyLevel).dmConversationRules.nonEmpty + def hasDmEventRules(safetyLevel: SafetyLevel): Boolean = RuleMap( + safetyLevel).dmEventRules.nonEmpty +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rules.scala new file mode 100644 index 000000000..9e8fa1c38 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/Rules.scala @@ -0,0 +1,315 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.configapi.params.RuleParams.EnableAuthorBlocksViewerDropRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableTimelineHomePromotedTweetHealthEnforcementRules +import com.twitter.visibility.configapi.params.RuleParams.EnableViewerIsSoftUserDropRuleParam +import com.twitter.visibility.configapi.params.RuleParams.PromotedTweetHealthEnforcementHoldback +import 
com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.IsQuotedInnerTweet +import com.twitter.visibility.rules.Condition.NonAuthorViewer +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.Retweet +import com.twitter.visibility.rules.Condition.SoftViewer +import com.twitter.visibility.rules.Reason._ + +object DropAllRule + extends AlwaysActRule( + Drop(Unspecified) + ) + +object AllowAllRule + extends AlwaysActRule( + Allow + ) + +object TestRule + extends AlwaysActRule( + Drop(Unspecified) + ) + +object DeactivatedAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(DeactivatedAuthor), + Condition.DeactivatedAuthor + ) + +object ErasedAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(ErasedAuthor), + Condition.ErasedAuthor + ) + +object OffboardedAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(OffboardedAuthor), + Condition.OffboardedAuthor + ) + +object DropNsfwUserAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Nsfw), + Condition.NsfwUserAuthor + ) + +object DropNsfwUserAuthorViewerOptInFilteringOnSearchRule + extends ViewerOptInFilteringOnSearchRule( + Drop(Nsfw), + Condition.NsfwUserAuthor + ) + +object InterstitialNsfwUserAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Nsfw), + Condition.NsfwUserAuthor + ) + +object DropNsfwAdminAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Nsfw), + Condition.NsfwAdminAuthor + ) + +object DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule + extends ViewerOptInFilteringOnSearchRule( + Drop(Nsfw), + Condition.NsfwAdminAuthor + ) + +object InterstitialNsfwAdminAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Nsfw), + Condition.NsfwAdminAuthor + ) + +object ProtectedAuthorDropRule + extends RuleWithConstantAction( + Drop(Reason.ProtectedAuthor), + And(Condition.LoggedOutOrViewerNotFollowingAuthor, Condition.ProtectedAuthor) + ) + +object ProtectedAuthorTombstoneRule + extends RuleWithConstantAction( + Tombstone(Epitaph.Protected), + And(Condition.LoggedOutOrViewerNotFollowingAuthor, Condition.ProtectedAuthor) + ) + +object DropAllProtectedAuthorRule + extends RuleWithConstantAction( + Drop(Reason.ProtectedAuthor), + Condition.ProtectedAuthor + ) { + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) +} + +object ProtectedQuoteTweetAuthorRule + extends RuleWithConstantAction( + Drop(Reason.ProtectedAuthor), + And(Condition.OuterAuthorNotFollowingAuthor, Condition.ProtectedAuthor) + ) + +object DropProtectedViewerIfPresentRule + extends RuleWithConstantAction( + Drop(Reason.Unspecified), + And(Condition.LoggedInViewer, Condition.ProtectedViewer) + ) { + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) +} + +object SuspendedAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(SuspendedAuthor), + Condition.SuspendedAuthor + ) + +object SuspendedViewerRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Unspecified), + Condition.SuspendedViewer + ) + +object DeactivatedViewerRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Unspecified), + Condition.DeactivatedViewer + ) + +object ViewerIsUnmentionedRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerIsUnmentioned), + Condition.ViewerIsUnmentioned + ) + +abstract class AuthorBlocksViewerRule(override val action: Action) + extends OnlyWhenNotAuthorViewerRule( + action, + Condition.AuthorBlocksViewer + ) + +object AuthorBlocksViewerDropRule + extends AuthorBlocksViewerRule( + 
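+      // Drop removes the tweet from the response entirely; the Tombstone variant
+      // below (AuthorBlocksViewerTombstoneRule) leaves a "BlockedBy" placeholder.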
Drop(Reason.AuthorBlocksViewer) + ) + +object DeciderableAuthorBlocksViewerDropRule + extends AuthorBlocksViewerRule( + Drop(Reason.AuthorBlocksViewer) + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableAuthorBlocksViewerDropRuleParam) +} + +object AuthorBlocksViewerTombstoneRule + extends AuthorBlocksViewerRule( + Tombstone(Epitaph.BlockedBy) + ) + +object ViewerBlocksAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerBlocksAuthor), + Condition.ViewerBlocksAuthor + ) + +object ViewerBlocksAuthorViewerOptInBlockingOnSearchRule + extends ViewerOptInBlockingOnSearchRule( + Drop(Reason.ViewerBlocksAuthor), + Condition.ViewerBlocksAuthor + ) + +object ViewerMutesAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerMutesAuthor), + Condition.ViewerMutesAuthor + ) + +object ViewerMutesAuthorViewerOptInBlockingOnSearchRule + extends ViewerOptInBlockingOnSearchRule( + Drop(Reason.ViewerMutesAuthor), + Condition.ViewerMutesAuthor + ) + +object AuthorBlocksOuterAuthorRule + extends RuleWithConstantAction( + Drop(Reason.AuthorBlocksViewer), + And(Not(Condition.IsSelfQuote), Condition.AuthorBlocksOuterAuthor) + ) + +object ViewerMutesAndDoesNotFollowAuthorRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerHardMutedAuthor), + And(Condition.ViewerMutesAuthor, Not(Condition.ViewerDoesFollowAuthor)) + ) + +object AuthorBlocksViewerUnspecifiedRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.Unspecified), + Condition.AuthorBlocksViewer + ) + +object ViewerHasMatchingMutedKeywordForNotificationsRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordForNotifications + ) + +object ViewerHasMatchingMutedKeywordForHomeTimelineRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordForHomeTimeline + ) + +trait HasPromotedTweetHealthEnforcement extends WithGate { + override def holdbacks: Seq[RuleParam[Boolean]] = Seq(PromotedTweetHealthEnforcementHoldback) + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableTimelineHomePromotedTweetHealthEnforcementRules) +} + +object ViewerHasMatchingMutedKeywordForHomeTimelinePromotedTweetRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordForHomeTimeline + ) + with HasPromotedTweetHealthEnforcement + +object ViewerHasMatchingMutedKeywordForTweetRepliesRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordForTweetReplies + ) + +object MutedKeywordForTweetRepliesInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordForTweetReplies + ) + +object MutedKeywordForQuotedTweetTweetDetailInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.MutedKeyword), + And(Condition.IsQuotedInnerTweet, Condition.ViewerHasMatchingKeywordForTweetReplies) + ) + +object ViewerMutesAuthorInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.ViewerMutesAuthor), + Condition.ViewerMutesAuthor + ) + +object ViewerMutesAuthorInnerQuotedTweetInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.ViewerMutesAuthor), + And(Condition.ViewerMutesAuthor, IsQuotedInnerTweet) + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam) +} + +object ViewerMutesAuthorHomeTimelinePromotedTweetRule + extends 
OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerMutesAuthor), + Condition.ViewerMutesAuthor + ) + with HasPromotedTweetHealthEnforcement + +object ViewerBlocksAuthorInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.ViewerBlocksAuthor), + Condition.ViewerBlocksAuthor + ) + +object ViewerBlocksAuthorInnerQuotedTweetInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.ViewerBlocksAuthor), + And(Condition.ViewerBlocksAuthor, IsQuotedInnerTweet) + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam) +} + +object ViewerBlocksAuthorHomeTimelinePromotedTweetRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.ViewerBlocksAuthor), + Condition.ViewerBlocksAuthor + ) + with HasPromotedTweetHealthEnforcement + +object ViewerReportsAuthorInterstitialRule + extends OnlyWhenNotAuthorViewerRule( + Interstitial(Reason.ViewerReportedAuthor), + Condition.ViewerReportsAuthor + ) + +object ViewerIsAuthorDropRule + extends RuleWithConstantAction(Drop(Unspecified), Not(NonAuthorViewer)) + +object ViewerIsNotAuthorDropRule extends RuleWithConstantAction(Drop(Unspecified), NonAuthorViewer) + +object RetweetDropRule extends RuleWithConstantAction(Drop(Unspecified), Retweet) + +object ViewerIsSoftUserDropRule extends RuleWithConstantAction(Drop(ViewerIsSoftUser), SoftViewer) { + + override val enabled: Seq[RuleParam[Boolean]] = Seq(EnableViewerIsSoftUserDropRuleParam) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/SafeSearchRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SafeSearchRules.scala new file mode 100644 index 000000000..838e78a5b --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SafeSearchRules.scala @@ -0,0 +1,332 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableDownrankSpamReplySectioningRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableNotGraduatedSearchDropRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableNsfwTextSectioningRuleParam +import com.twitter.visibility.configapi.params.RuleParams.NotGraduatedUserLabelRuleHoldbackExperimentParam +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.models.UserLabelValue +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.LoggedOutOrViewerNotFollowingAuthor +import com.twitter.visibility.rules.Condition.LoggedOutOrViewerOptInFiltering +import com.twitter.visibility.rules.Condition.NonAuthorViewer +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.TweetComposedBefore +import com.twitter.visibility.rules.Condition.ViewerDoesFollowAuthor +import com.twitter.visibility.rules.Condition.ViewerOptInFilteringOnSearch +import com.twitter.visibility.rules.Reason.Nsfw +import com.twitter.visibility.rules.Reason.Unspecified +import com.twitter.visibility.rules.RuleActionSourceBuilder.TweetSafetyLabelSourceBuilder + +case object SafeSearchTweetRules { + + object SafeSearchAbusiveTweetLabelRule + extends NonAuthorViewerOptInFilteringWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.Abusive + ) + with DoesLogVerdict { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.Abusive)) + } + + 
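+  // Editor's note (illustrative sketch, not part of the original change): the rules in
+  // this object share one shape: drop or interstitial a safety-labeled tweet for
+  // non-author viewers who have search filtering enabled. A new label would be wired up
+  // the same way as SafeSearchAbusiveTweetLabelRule above, e.g.:
+  //
+  //   object SafeSearchSomeNewLabelRule // hypothetical rule name
+  //     extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+  //       Drop(Unspecified),
+  //       TweetSafetyLabelType.Abusive // substitute the real TweetSafetyLabelType
+  //     )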
object SafeSearchNsfwHighPrecisionTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwHighPrecision
+      )
+
+  object SafeSearchGoreAndViolenceHighPrecisionTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.GoreAndViolenceHighPrecision
+      )
+
+  object SafeSearchNsfwReportedHeuristicsTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwReportedHeuristics
+      )
+
+  object SafeSearchGoreAndViolenceReportedHeuristicsTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.GoreAndViolenceReportedHeuristics
+      )
+
+  object SafeSearchNsfwCardImageTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwCardImage
+      )
+
+  object SafeSearchNsfwHighRecallTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwHighRecall
+      )
+
+  object SafeSearchNsfwVideoTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwVideo
+      )
+
+  object SafeSearchNsfwTextTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Nsfw),
+        TweetSafetyLabelType.NsfwText
+      ) {
+    override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwTextSectioningRuleParam)
+  }
+
+  object SafeSearchNsfwTextAuthorLabelRule
+      extends ViewerOptInFilteringOnSearchUserLabelRule(
+        Drop(Nsfw),
+        UserLabelValue.NsfwText
+      ) {
+    override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwTextSectioningRuleParam)
+  }
+
+  object SafeSearchGoreAndViolenceTweetLabelRule
+      extends ConditionWithTweetLabelRule(
+        Drop(Unspecified),
+        And(
+          NonAuthorViewer,
+          TweetComposedBefore(TweetSafetyLabelType.GoreAndViolence.DeprecatedAt),
+          LoggedOutOrViewerOptInFiltering
+        ),
+        TweetSafetyLabelType.GoreAndViolence
+      )
+
+  object SafeSearchUntrustedUrlTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Unspecified),
+        TweetSafetyLabelType.UntrustedUrl
+      )
+
+  object SafeSearchDownrankSpamReplyTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Unspecified),
+        TweetSafetyLabelType.DownrankSpamReply
+      ) {
+    override def enabled: Seq[RuleParam[Boolean]] =
+      Seq(EnableDownrankSpamReplySectioningRuleParam)
+  }
+
+  object SafeSearchDownrankSpamReplyAuthorLabelRule
+      extends ViewerOptInFilteringOnSearchUserLabelRule(
+        Drop(Unspecified),
+        UserLabelValue.DownrankSpamReply
+      ) {
+    override def enabled: Seq[RuleParam[Boolean]] =
+      Seq(EnableDownrankSpamReplySectioningRuleParam)
+  }
+
+  object SafeSearchAutomationNonFollowerTweetLabelRule
+      extends NonFollowerViewerOptInFilteringWithTweetLabelRule(
+        Drop(Unspecified),
+        TweetSafetyLabelType.Automation
+      )
+
+  object SafeSearchDuplicateMentionNonFollowerTweetLabelRule
+      extends NonFollowerViewerOptInFilteringWithTweetLabelRule(
+        Drop(Unspecified),
+        TweetSafetyLabelType.DuplicateMention
+      )
+
+  object SafeSearchBystanderAbusiveTweetLabelRule
+      extends NonAuthorViewerOptInFilteringWithTweetLabelRule(
+        Drop(Unspecified),
+        TweetSafetyLabelType.BystanderAbusive
+      )
+}
+
+case object UnsafeSearchTweetRules {
+
+  object UnsafeSearchNsfwHighPrecisionInterstitialAllUsersTweetLabelRule
+      extends ConditionWithTweetLabelRule(
+        Interstitial(Nsfw),
+        Not(ViewerOptInFilteringOnSearch),
+        TweetSafetyLabelType.NsfwHighPrecision
+      )
+
+  object
UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelRule + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + + object UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelDropRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + + object UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelRule + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.NsfwReportedHeuristics + ) + + object UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelDropRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.NsfwReportedHeuristics + ) + + object UnsafeSearchNsfwHighPrecisionAllUsersTweetLabelDropRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.NsfwHighPrecision + ) + + object UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + + object UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelDropRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + + object UnsafeSearchNsfwCardImageAllUsersTweetLabelRule + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.NsfwCardImage + ) + + object UnsafeSearchNsfwCardImageAllUsersTweetLabelDropRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + Not(ViewerOptInFilteringOnSearch), + TweetSafetyLabelType.NsfwCardImage + ) + +} + +case object SafeSearchUserRules { + + object SafeSearchAbusiveUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.Abusive + ) + + object SafeSearchAbusiveHighRecallUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.AbusiveHighRecall, + LoggedOutOrViewerNotFollowingAuthor + ) + + object SafeSearchHighRecallUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Nsfw), + UserLabelValue.NsfwHighRecall, + LoggedOutOrViewerNotFollowingAuthor + ) + + object SafeSearchCompromisedUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.Compromised + ) + + object SafeSearchDuplicateContentUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.DuplicateContent + ) + + object SafeSearchLowQualityUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.LowQuality + ) + + object SafeSearchNsfwHighPrecisionUserLabelRule + extends ConditionWithUserLabelRule( + Drop(Nsfw), + LoggedOutOrViewerOptInFiltering, + UserLabelValue.NsfwHighPrecision + ) + + object SafeSearchNsfwAvatarImageUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Nsfw), + UserLabelValue.NsfwAvatarImage + ) + + object SafeSearchNsfwBannerImageUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Nsfw), + UserLabelValue.NsfwBannerImage + ) + + object SafeSearchNsfwNearPerfectAuthorRule + extends 
ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Nsfw), + UserLabelValue.NsfwNearPerfect + ) + + object SafeSearchReadOnlyUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.ReadOnly + ) + + object SafeSearchSpamHighRecallUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.SpamHighRecall + ) + + object SafeSearchSearchBlacklistUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.SearchBlacklist + ) + + object SafeSearchNsfwTextUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.NsfwText + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNsfwTextSectioningRuleParam) + } + + object SafeSearchDoNotAmplifyNonFollowersUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.DoNotAmplify, + prerequisiteCondition = Not(ViewerDoesFollowAuthor) + ) + + object SafeSearchNotGraduatedNonFollowersUserLabelRule + extends ViewerOptInFilteringOnSearchUserLabelRule( + Drop(Unspecified), + UserLabelValue.NotGraduated, + prerequisiteCondition = Not(ViewerDoesFollowAuthor) + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNotGraduatedSearchDropRuleParam) + + override def holdbacks: Seq[RuleParam[Boolean]] = + Seq(NotGraduatedUserLabelRuleHoldbackExperimentParam) + + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/SearchBlenderRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SearchBlenderRules.scala new file mode 100644 index 000000000..8e7c54582 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SearchBlenderRules.scala @@ -0,0 +1,37 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.HasSearchCandidateCountGreaterThan45 +import com.twitter.visibility.rules.Condition.IsFirstPageSearchResult +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Reason.FirstPageSearchResult + +abstract class FirstPageSearchResultWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + IsFirstPageSearchResult, + tweetSafetyLabelType + ) + +abstract class FirstPageSearchResultSmartOutOfNetworkWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + And( + IsFirstPageSearchResult, + HasSearchCandidateCountGreaterThan45, + Condition.NonAuthorViewer, + Not(Condition.ViewerDoesFollowAuthor), + Not(Condition.VerifiedAuthor) + ), + tweetSafetyLabelType + ) + +object FirstPageSearchResultAgathaSpamDropRule + extends FirstPageSearchResultWithTweetLabelRule( + Drop(FirstPageSearchResult), + TweetSafetyLabelType.AgathaSpam) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/SensitiveMediaSettingsRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SensitiveMediaSettingsRules.scala new file mode 100644 index 000000000..2fe315afe --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SensitiveMediaSettingsRules.scala @@ -0,0 +1,277 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.rules.Condition.ViewerHasAdultMediaSettingLevel +import 
com.twitter.visibility.rules.Condition.ViewerHasViolentMediaSettingLevel +import com.twitter.visibility.rules.Condition.ViewerHasOtherSensitiveMediaSettingLevel +import com.twitter.visibility.rules.Condition.LoggedInViewer +import com.twitter.visibility.rules.Condition.LoggedOutViewer +import com.twitter.visibility.rules.Condition.TweetHasNsfwUserAuthor +import com.twitter.visibility.rules.Condition.TweetHasNsfwAdminAuthor +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.Or +import com.twitter.visibility.rules.Condition.Not +import com.twitter.visibility.rules.Condition.NonAuthorViewer +import com.twitter.visibility.rules.Condition.TweetHasMedia +import com.twitter.visibility.rules.Reason.Nsfw +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.contenthealth.sensitivemediasettings.thriftscala.SensitiveMediaSettingsLevel + + +abstract class AdultMediaTweetLabelDropRule(tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + And(LoggedInViewer, ViewerHasAdultMediaSettingLevel(SensitiveMediaSettingsLevel.Drop)), + tweetSafetyLabelType + ) + +abstract class ViolentMediaTweetLabelDropRule(tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + And(LoggedInViewer, ViewerHasViolentMediaSettingLevel(SensitiveMediaSettingsLevel.Drop)), + tweetSafetyLabelType + ) + +abstract class OtherSensitiveMediaTweetLabelDropRule(condition: Condition) + extends RuleWithConstantAction( + Drop(Nsfw), + And( + condition, + And( + TweetHasMedia, + LoggedInViewer, + ViewerHasOtherSensitiveMediaSettingLevel(SensitiveMediaSettingsLevel.Drop))) + ) + +abstract class AdultMediaTweetLabelInterstitialRule(tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Or( + LoggedOutViewer, + ViewerHasAdultMediaSettingLevel(SensitiveMediaSettingsLevel.Warn), + Not(ViewerHasAdultMediaSettingLevel(SensitiveMediaSettingsLevel.Allow)) + ), + tweetSafetyLabelType + ) + +abstract class ViolentMediaTweetLabelInterstitialRule(tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Interstitial(Nsfw), + Or( + LoggedOutViewer, + ViewerHasViolentMediaSettingLevel(SensitiveMediaSettingsLevel.Warn), + Not(ViewerHasViolentMediaSettingLevel(SensitiveMediaSettingsLevel.Allow)) + ), + tweetSafetyLabelType + ) + +abstract class OtherSensitiveMediaTweetLabelInterstitialRule(condition: Condition) + extends RuleWithConstantAction( + Interstitial(Nsfw), + And( + condition, + TweetHasMedia, + Or( + LoggedOutViewer, + ViewerHasOtherSensitiveMediaSettingLevel(SensitiveMediaSettingsLevel.Warn) + ) + ) + ) + +abstract class AdultMediaTweetLabelDropSettingLevelTombstoneRule( + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Tombstone(Epitaph.AdultMedia), + And( + LoggedInViewer, + NonAuthorViewer, + ViewerHasAdultMediaSettingLevel(SensitiveMediaSettingsLevel.Drop)), + tweetSafetyLabelType + ) + +abstract class ViolentMediaTweetLabelDropSettingLevelTombstoneRule( + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + Tombstone(Epitaph.ViolentMedia), + And( + LoggedInViewer, + NonAuthorViewer, + ViewerHasViolentMediaSettingLevel(SensitiveMediaSettingsLevel.Drop)), + tweetSafetyLabelType + ) + +abstract class OtherSensitiveMediaTweetLabelDropSettingLevelTombstoneRule(condition: Condition) + extends RuleWithConstantAction( + 
Tombstone(Epitaph.OtherSensitiveMedia), + And( + condition, + And( + TweetHasMedia, + LoggedInViewer, + NonAuthorViewer, + ViewerHasOtherSensitiveMediaSettingLevel(SensitiveMediaSettingsLevel.Drop)) + ) + ) + +case object SensitiveMediaTweetDropRules { + + + object AdultMediaNsfwHighPrecisionTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwHighPrecision + ) + + object AdultMediaNsfwCardImageTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwCardImage + ) + + object AdultMediaNsfwReportedHeuristicsTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwReportedHeuristics + ) + + object AdultMediaNsfwVideoTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwVideo + ) + + object AdultMediaNsfwHighRecallTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwHighRecall + ) + + object AdultMediaNsfwTextTweetLabelDropRule + extends AdultMediaTweetLabelDropRule( + TweetSafetyLabelType.NsfwText + ) + + object ViolentMediaGoreAndViolenceHighPrecisionDropRule + extends ViolentMediaTweetLabelDropRule( + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + + object ViolentMediaGoreAndViolenceReportedHeuristicsDropRule + extends ViolentMediaTweetLabelDropRule( + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + + object OtherSensitiveMediaNsfwUserTweetFlagDropRule + extends OtherSensitiveMediaTweetLabelDropRule( + TweetHasNsfwUserAuthor + ) + + object OtherSensitiveMediaNsfwAdminTweetFlagDropRule + extends OtherSensitiveMediaTweetLabelDropRule( + TweetHasNsfwAdminAuthor + ) +} + +case object SensitiveMediaTweetInterstitialRules { + + object AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwHighPrecision + ) + with DoesLogVerdict + + object AdultMediaNsfwCardImageTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwCardImage + ) + + object AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwReportedHeuristics + ) + + object AdultMediaNsfwVideoTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwVideo + ) + + object AdultMediaNsfwHighRecallTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwHighRecall + ) + + object AdultMediaNsfwTextTweetLabelInterstitialRule + extends AdultMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.NsfwText + ) + + object ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule + extends ViolentMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + with DoesLogVerdict + + object ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule + extends ViolentMediaTweetLabelInterstitialRule( + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + + object OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule + extends OtherSensitiveMediaTweetLabelInterstitialRule( + TweetHasNsfwUserAuthor + ) + + object OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule + extends OtherSensitiveMediaTweetLabelInterstitialRule( + TweetHasNsfwAdminAuthor + ) + +} + +case object SensitiveMediaTweetDropSettingLevelTombstoneRules { + + + object AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + 
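+      // tombstoned (rather than silently dropped) when the viewer's adult-media
+      // setting is SensitiveMediaSettingsLevel.Drop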
TweetSafetyLabelType.NsfwHighPrecision + ) + + object AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.NsfwCardImage + ) + + object AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.NsfwReportedHeuristics + ) + + object AdultMediaNsfwVideoTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.NsfwVideo + ) + + object AdultMediaNsfwHighRecallTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.NsfwHighRecall + ) + + object AdultMediaNsfwTextTweetLabelDropSettingLevelTombstoneRule + extends AdultMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.NsfwText + ) + + object ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule + extends ViolentMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + + object ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule + extends ViolentMediaTweetLabelDropSettingLevelTombstoneRule( + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + + object OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule + extends OtherSensitiveMediaTweetLabelDropSettingLevelTombstoneRule( + TweetHasNsfwUserAuthor + ) + + object OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule + extends OtherSensitiveMediaTweetLabelDropSettingLevelTombstoneRule( + TweetHasNsfwAdminAuthor + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/SpaceRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SpaceRules.scala new file mode 100644 index 000000000..a6d771d27 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/SpaceRules.scala @@ -0,0 +1,219 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.FSRuleParams.HighToxicityModelScoreSpaceThresholdParam +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableMutedKeywordFilteringSpaceTitleNotificationsRuleParam +import com.twitter.visibility.models.SpaceSafetyLabelType.CoordinatedHarmfulActivityHighRecall +import com.twitter.visibility.models.SpaceSafetyLabelType.DoNotAmplify +import com.twitter.visibility.models.SpaceSafetyLabelType.MisleadingHighRecall +import com.twitter.visibility.models.SpaceSafetyLabelType.NsfwHighPrecision +import com.twitter.visibility.models.SpaceSafetyLabelType.NsfwHighRecall +import com.twitter.visibility.models.SpaceSafetyLabelType.UntrustedUrl +import com.twitter.visibility.models.UserLabelValue.Abusive +import com.twitter.visibility.models.UserLabelValue.BlinkWorst +import com.twitter.visibility.models.UserLabelValue.DelayedRemediation +import com.twitter.visibility.models.UserLabelValue.NsfwAvatarImage +import com.twitter.visibility.models.UserLabelValue.NsfwBannerImage +import com.twitter.visibility.models.UserLabelValue.NsfwNearPerfect +import com.twitter.visibility.models.SpaceSafetyLabelType +import com.twitter.visibility.models.SpaceSafetyLabelType.HatefulHighRecall +import com.twitter.visibility.models.SpaceSafetyLabelType.HighToxicityModelScore +import com.twitter.visibility.models.SpaceSafetyLabelType.ViolenceHighRecall +import 
com.twitter.visibility.models.UserLabelValue +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Reason.Nsfw +import com.twitter.visibility.rules.Reason.Unspecified + +object SpaceRules { + + abstract class SpaceHasLabelRule( + action: Action, + safetyLabelType: SpaceSafetyLabelType) + extends RuleWithConstantAction(action, And(SpaceHasLabel(safetyLabelType), NonAuthorViewer)) + + abstract class SpaceHasLabelAndNonFollowerRule( + action: Action, + safetyLabelType: SpaceSafetyLabelType) + extends RuleWithConstantAction( + action, + And(SpaceHasLabel(safetyLabelType), LoggedOutOrViewerNotFollowingAuthor)) + + abstract class AnySpaceHostOrAdminHasLabelRule( + action: Action, + userLabel: UserLabelValue) + extends WhenAuthorUserLabelPresentRule(action, userLabel) + + abstract class AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + action: Action, + userLabel: UserLabelValue) + extends ConditionWithUserLabelRule(action, LoggedOutOrViewerNotFollowingAuthor, userLabel) + + + object SpaceDoNotAmplifyAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Unspecified), + DoNotAmplify, + ) + + object SpaceDoNotAmplifyNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Unspecified), + DoNotAmplify, + ) + + object SpaceCoordHarmfulActivityHighRecallAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Unspecified), + CoordinatedHarmfulActivityHighRecall, + ) + + object SpaceCoordHarmfulActivityHighRecallNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Unspecified), + CoordinatedHarmfulActivityHighRecall, + ) + + object SpaceUntrustedUrlAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Unspecified), + UntrustedUrl, + ) + + object SpaceUntrustedUrlNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Unspecified), + UntrustedUrl, + ) + + object SpaceMisleadingHighRecallNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Unspecified), + MisleadingHighRecall, + ) + + object SpaceNsfwHighPrecisionAllUsersInterstitialRule + extends SpaceHasLabelRule( + Interstitial(Nsfw), + NsfwHighPrecision, + ) + + object SpaceNsfwHighPrecisionAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Nsfw), + NsfwHighPrecision, + ) + + object SpaceNsfwHighPrecisionNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Nsfw), + NsfwHighPrecision, + ) + + object SpaceNsfwHighPrecisionSafeSearchNonFollowerDropRule + extends RuleWithConstantAction( + Drop(Nsfw), + And( + SpaceHasLabel(NsfwHighPrecision), + NonAuthorViewer, + LoggedOutOrViewerOptInFiltering, + Not(ViewerDoesFollowAuthor), + ), + ) + + object SpaceNsfwHighRecallAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Nsfw), + NsfwHighRecall, + ) + + object SpaceNsfwHighRecallNonFollowerDropRule + extends SpaceHasLabelAndNonFollowerRule( + Drop(Nsfw), + NsfwHighRecall, + ) + + object SpaceNsfwHighRecallSafeSearchNonFollowerDropRule + extends RuleWithConstantAction( + Drop(Nsfw), + And( + SpaceHasLabel(NsfwHighRecall), + NonAuthorViewer, + LoggedOutOrViewerOptInFiltering, + Not(ViewerDoesFollowAuthor), + ), + ) + + object SpaceHatefulHighRecallAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Unspecified), + HatefulHighRecall, + ) + + object SpaceViolenceHighRecallAllUsersDropRule + extends SpaceHasLabelRule( + Drop(Unspecified), + ViolenceHighRecall, + ) + + object SpaceHighToxicityScoreNonFollowerDropRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + SpaceHasLabelWithScoreAboveThresholdWithParam( + HighToxicityModelScore, + 
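+            // toxicity threshold is read per-request from feature switches (FSRuleParams)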
HighToxicityModelScoreSpaceThresholdParam + ), + NonAuthorViewer, + LoggedOutOrViewerNotFollowingAuthor, + ) + ) + with ExperimentalRule + + + object ViewerHasMatchingMutedKeywordInSpaceTitleForNotificationsRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Reason.MutedKeyword), + Condition.ViewerHasMatchingKeywordInSpaceTitleForNotifications + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableMutedKeywordFilteringSpaceTitleNotificationsRuleParam) + + } + + + object UserAbusiveNonFollowerDropRule + extends AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + Drop(Unspecified), + Abusive + ) + + object UserBlinkWorstAllUsersDropRule + extends AnySpaceHostOrAdminHasLabelRule( + Drop(Unspecified), + BlinkWorst + ) + + object UserNsfwNearPerfectNonFollowerDropRule + extends AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + Drop(Nsfw), + NsfwNearPerfect + ) + + object UserNsfwHighPrecisionNonFollowerDropRule + extends AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + Drop(Nsfw), + UserLabelValue.NsfwHighPrecision + ) + + object UserNsfwAvatarImageNonFollowerDropRule + extends AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + Drop(Nsfw), + NsfwAvatarImage + ) + + object UserNsfwBannerImageNonFollowerDropRule + extends AnySpaceHostOrAdminHasLabelAndNonFollowerRule( + Drop(Nsfw), + NsfwBannerImage + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/TombstoneIf.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TombstoneIf.scala new file mode 100644 index 000000000..38f9276ff --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TombstoneIf.scala @@ -0,0 +1,44 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.IsFocalTweet +import com.twitter.visibility.rules.Condition.Not + +object TombstoneIf { + + object AuthorIsProtected + extends RuleWithConstantAction( + Tombstone(Epitaph.Protected), + And( + Condition.LoggedOutOrViewerNotFollowingAuthor, + Condition.ProtectedAuthor + ) + ) + + object ReplyIsModeratedByRootAuthor + extends RuleWithConstantAction( + Tombstone(Epitaph.Moderated), + And( + Not(IsFocalTweet), + Condition.Moderated + ) + ) + + object ViewerIsBlockedByAuthor + extends OnlyWhenNotAuthorViewerRule( + Tombstone(Epitaph.BlockedBy), + Condition.AuthorBlocksViewer + ) + + object AuthorIsDeactivated + extends RuleWithConstantAction( + Tombstone(Epitaph.Deactivated), + Condition.DeactivatedAuthor + ) + + object AuthorIsSuspended + extends RuleWithConstantAction( + Tombstone(Epitaph.Suspended), + Condition.SuspendedAuthor + ) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/ToxicityReplyFilterRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ToxicityReplyFilterRules.scala new file mode 100644 index 000000000..01a1ab393 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/ToxicityReplyFilterRules.scala @@ -0,0 +1,28 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.rules.Reason.Toxicity + +object ToxicityReplyFilterRules { + + sealed abstract class ToxicityReplyFilterBaseRule( + action: Action) + extends RuleWithConstantAction( + action = action, + condition = Condition.ToxrfFilteredFromAuthorViewer) + + object ToxicityReplyFilterRule + extends ToxicityReplyFilterBaseRule(action = Tombstone(Epitaph.Unavailable)) { + + 
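+    // Editor's note: the conversation and notification variants are gated by separate
+    // RuleParams, so toxic-reply filtering can be enabled independently per surface.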
override def enabled: Seq[RuleParam[Boolean]] = Seq( + RuleParams.EnableToxicReplyFilteringConversationRulesParam) + } + + object ToxicityReplyFilterDropNotificationRule + extends ToxicityReplyFilterBaseRule(action = Drop(Toxicity)) { + + override def enabled: Seq[RuleParam[Boolean]] = Seq( + RuleParams.EnableToxicReplyFilteringNotificationsRulesParam) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetLabelRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetLabelRules.scala new file mode 100644 index 000000000..11f2ef7f5 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetLabelRules.scala @@ -0,0 +1,862 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.common.ModelScoreThresholds +import com.twitter.visibility.common.actions.AvoidReason +import com.twitter.visibility.common.actions.AvoidReason.MightNotBeSuitableForAds +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.common.actions.TweetVisibilityNudgeReason +import com.twitter.visibility.configapi.configs.DeciderKey +import com.twitter.visibility.configapi.params.FSRuleParams.HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.SkipTweetDetailLimitedEngagementRuleEnabledParam +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams._ +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Condition.{True => TrueCondition} +import com.twitter.visibility.rules.Reason._ +import com.twitter.visibility.rules.RuleActionSourceBuilder.TweetSafetyLabelSourceBuilder + +object AbusiveTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.Abusive + ) + with DoesLogVerdict + +object AbusiveNonFollowerTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Drop(Toxicity), + TweetSafetyLabelType.Abusive + ) + +object AbusiveUqfNonFollowerTweetLabelRule + extends NonFollowerWithUqfTweetLabelRule( + Drop(Toxicity), + TweetSafetyLabelType.Abusive + ) + +object AbusiveHighRecallTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.AbusiveHighRecall + ) + +object AbusiveHighRecallNonFollowerTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Interstitial(PossiblyUndesirable), + TweetSafetyLabelType.AbusiveHighRecall + ) + +object AutomationTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.Automation + ) + +object BystanderAbusiveTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.BystanderAbusive + ) + +object BystanderAbusiveNonFollowerTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.BystanderAbusive + ) + +abstract class DuplicateContentTweetLabelRule(action: Action) + extends NonAuthorWithTweetLabelRule( + action, + 
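+      // the action is a constructor parameter so the Drop and Tombstone variants
+      // below can share this label check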
TweetSafetyLabelType.DuplicateContent + ) + +object DuplicateContentTweetLabelDropRule + extends DuplicateContentTweetLabelRule(Drop(TweetLabelDuplicateContent)) + +object DuplicateContentTweetLabelTombstoneRule + extends DuplicateContentTweetLabelRule(Tombstone(Epitaph.Unavailable)) + +object DuplicateMentionTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.DuplicateMention + ) + +object DuplicateMentionUqfTweetLabelRule + extends NonFollowerWithUqfTweetLabelRule( + Drop(TweetLabelDuplicateMention), + TweetSafetyLabelType.DuplicateMention + ) + +object GoreAndViolenceTweetLabelRule + extends ConditionWithTweetLabelRule( + Drop(Unspecified), + And( + NonAuthorViewer, + TweetComposedBefore(TweetSafetyLabelType.GoreAndViolence.DeprecatedAt) + ), + TweetSafetyLabelType.GoreAndViolence + ) + +object LiveLowQualityTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.LiveLowQuality + ) + +object LowQualityMentionTweetLabelRule + extends RuleWithConstantAction( + Drop(LowQualityMention), + And( + TweetHasLabelForPerspectivalUser(TweetSafetyLabelType.LowQualityMention), + ViewerHasUqfEnabled + ) + ) + +abstract class NsfwCardImageTweetLabelBaseRule( + override val action: Action, + val additionalCondition: Condition = TrueCondition, +) extends RuleWithConstantAction( + action, + And( + additionalCondition, + TweetHasLabel(TweetSafetyLabelType.NsfwCardImage) + ) + ) + +object NsfwCardImageTweetLabelRule + extends NsfwCardImageTweetLabelBaseRule( + action = Drop(Nsfw), + additionalCondition = NonAuthorViewer, + ) + +object NsfwCardImageAllUsersTweetLabelRule + extends NsfwCardImageTweetLabelBaseRule( + action = Interstitial(Nsfw) + ) + +object NsfwCardImageAvoidAllUsersTweetLabelRule + extends NsfwCardImageTweetLabelBaseRule( + action = Avoid(Some(AvoidReason.ContainsNsfwMedia)), + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule + extends NsfwCardImageTweetLabelBaseRule( + action = Avoid(Some(AvoidReason.ContainsNsfwMedia)), + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object SearchAvoidTweetNsfwAdminRule + extends RuleWithConstantAction( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetHasNsfwAdminAuthor + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object SearchAvoidTweetNsfwUserRule + extends RuleWithConstantAction( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetHasNsfwUserAuthor + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object NsfwCardImageAllUsersTweetLabelDropRule + extends NsfwCardImageTweetLabelBaseRule( + action = Drop(Nsfw), + ) + +object HighProactiveTosScoreTweetLabelDropRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.HighProactiveTosScore + ) + +object HighProactiveTosScoreTweetLabelDropSearchRule + extends NonAuthorAndNonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.HighProactiveTosScore + ) + +object NsfwHighPrecisionTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwHighPrecision + ) + +object NsfwHighPrecisionAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwHighPrecision + ) + +object NsfwHighPrecisionInnerQuotedTweetLabelRule + extends ConditionWithTweetLabelRule( + Drop(Nsfw), + 
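+      // applies only to the inner tweet of a quote tweet; the tombstone variant
+      // below is gated by its own RuleParam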
And(IsQuotedInnerTweet, NonAuthorViewer), + TweetSafetyLabelType.NsfwHighPrecision + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwHpQuotedTweetDropRuleParam) +} + +object NsfwHighPrecisionTombstoneInnerQuotedTweetLabelRule + extends ConditionWithTweetLabelRule( + Tombstone(Epitaph.Unavailable), + And(IsQuotedInnerTweet, NonAuthorViewer), + TweetSafetyLabelType.NsfwHighPrecision + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwHpQuotedTweetTombstoneRuleParam) +} + +object GoreAndViolenceHighPrecisionTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + +object NsfwReportedHeuristicsTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwReportedHeuristics + ) + +object GoreAndViolenceReportedHeuristicsTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + +object NsfwHighPrecisionInterstitialAllUsersTweetLabelRule + extends TweetHasLabelRule( + Interstitial(Nsfw), + TweetSafetyLabelType.NsfwHighPrecision + ) + with DoesLogVerdict + +object GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object GoreAndViolenceHighPrecisionAllUsersTweetLabelRule + extends TweetHasLabelRule( + Interstitial(Nsfw), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + with DoesLogVerdict { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.GoreAndViolenceHighPrecision) + ) +} + +object NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.NsfwReportedHeuristics + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.NsfwReportedHeuristics + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object NsfwReportedHeuristicsAllUsersTweetLabelRule + extends TweetHasLabelRule( + Interstitial(Nsfw), + TweetSafetyLabelType.NsfwReportedHeuristics + ) + +object GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule + extends TweetHasLabelRule( + Interstitial(Nsfw), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + +object GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableAvoidNsfwRulesParam) +} + +object GoreAndViolenceHighPrecisionAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.GoreAndViolenceHighPrecision + ) + +object NsfwReportedHeuristicsAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + 
Drop(Nsfw), + TweetSafetyLabelType.NsfwReportedHeuristics + ) + +object GoreAndViolenceReportedHeuristicsAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.GoreAndViolenceReportedHeuristics + ) + +object NsfwHighRecallTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwHighRecall + ) + +object NsfwHighRecallAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwHighRecall + ) + +abstract class PdnaTweetLabelRule( + override val action: Action, + val additionalCondition: Condition) + extends ConditionWithTweetLabelRule( + action, + And(NonAuthorViewer, additionalCondition), + TweetSafetyLabelType.Pdna + ) + +object PdnaTweetLabelRule extends PdnaTweetLabelRule(Drop(PdnaTweet), Condition.True) + +object PdnaTweetLabelTombstoneRule + extends PdnaTweetLabelRule(Tombstone(Epitaph.Unavailable), Condition.True) + +object PdnaQuotedTweetLabelTombstoneRule + extends PdnaTweetLabelRule(Tombstone(Epitaph.Unavailable), Condition.IsQuotedInnerTweet) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnablePdnaQuotedTweetTombstoneRuleParam) +} + +object PdnaAllUsersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.Pdna + ) + +object SearchBlacklistTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.SearchBlacklist + ) + +object SearchBlacklistHighRecallTweetLabelDropRule + extends NonAuthorAndNonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.SearchBlacklistHighRecall + ) + +abstract class SpamTweetLabelRule( + override val action: Action, + val additionalCondition: Condition) + extends ConditionWithTweetLabelRule( + action, + And(NonAuthorViewer, additionalCondition), + TweetSafetyLabelType.Spam + ) + with DoesLogVerdict + +object SpamTweetLabelRule extends SpamTweetLabelRule(Drop(TweetLabeledSpam), Condition.True) + +object SpamTweetLabelTombstoneRule + extends SpamTweetLabelRule(Tombstone(Epitaph.Unavailable), Condition.True) + +object SpamQuotedTweetLabelTombstoneRule + extends SpamTweetLabelRule(Tombstone(Epitaph.Unavailable), Condition.IsQuotedInnerTweet) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableSpamQuotedTweetTombstoneRuleParam) +} + +object SpamAllUsersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.Spam + ) + +abstract class BounceTweetLabelRule(override val action: Action) + extends NonAuthorWithTweetLabelRule( + action, + TweetSafetyLabelType.Bounce + ) + +object BounceTweetLabelRule extends BounceTweetLabelRule(Drop(Bounce)) + +object BounceTweetLabelTombstoneRule extends BounceTweetLabelRule(Tombstone(Epitaph.Bounced)) + +abstract class BounceOuterTweetLabelRule(override val action: Action) + extends ConditionWithTweetLabelRule( + action, + And(Not(Condition.IsQuotedInnerTweet), NonAuthorViewer), + TweetSafetyLabelType.Bounce + ) + +object BounceOuterTweetTombstoneRule extends BounceOuterTweetLabelRule(Tombstone(Epitaph.Bounced)) + +object BounceQuotedTweetTombstoneRule + extends ConditionWithTweetLabelRule( + Tombstone(Epitaph.Bounced), + Condition.IsQuotedInnerTweet, + TweetSafetyLabelType.Bounce + ) + +object BounceAllUsersTweetLabelRule + extends TweetHasLabelRule( + Drop(Bounce), + TweetSafetyLabelType.Bounce + ) + + +abstract class SpamHighRecallTweetLabelRule(action: Action) + extends NonAuthorWithTweetLabelRule( + action, + TweetSafetyLabelType.SpamHighRecall + ) + +object 
SpamHighRecallTweetLabelDropRule + extends SpamHighRecallTweetLabelRule(Drop(SpamHighRecallTweet)) + +object SpamHighRecallTweetLabelTombstoneRule + extends SpamHighRecallTweetLabelRule(Tombstone(Epitaph.Unavailable)) + +object UntrustedUrlAllViewersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.UntrustedUrl + ) + +object DownrankSpamReplyAllViewersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.DownrankSpamReply + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) +} + +object UntrustedUrlTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.UntrustedUrl + ) + +object DownrankSpamReplyTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.DownrankSpamReply + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) +} + +object UntrustedUrlUqfNonFollowerTweetLabelRule + extends NonFollowerWithUqfTweetLabelRule( + Drop(UntrustedUrl), + TweetSafetyLabelType.UntrustedUrl + ) + +object DownrankSpamReplyUqfNonFollowerTweetLabelRule + extends NonFollowerWithUqfTweetLabelRule( + Drop(SpamReplyDownRank), + TweetSafetyLabelType.DownrankSpamReply + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) +} + +object NsfaHighRecallTweetLabelRule + extends RuleWithConstantAction( + Drop(Unspecified), + And( + NonAuthorViewer, + TweetHasLabel(TweetSafetyLabelType.NsfaHighRecall) + ) + ) + +object NsfaHighRecallTweetLabelInterstitialRule + extends RuleWithConstantAction( + Interstitial(Unspecified), + And( + NonAuthorViewer, + TweetHasLabel(TweetSafetyLabelType.NsfaHighRecall) + ) + ) + +object NsfwVideoTweetLabelDropRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwVideo + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwTextSectioningRuleParam) +} + +object NsfwTextTweetLabelDropRule + extends NonAuthorWithTweetLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwText + ) + +object NsfwVideoAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwVideo + ) + +object NsfwTextAllUsersTweetLabelDropRule + extends TweetHasLabelRule( + Drop(Nsfw), + TweetSafetyLabelType.NsfwText + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwTextSectioningRuleParam) +} + +abstract class BaseLowQualityTweetLabelRule(action: Action) + extends RuleWithConstantAction( + action, + And( + TweetHasLabel(TweetSafetyLabelType.LowQuality), + TweetComposedBefore(PublicInterest.PolicyConfig.LowQualityProxyLabelStart), + NonAuthorViewer + ) + ) + with DoesLogVerdict { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.LowQuality)) +} + +object LowQualityTweetLabelDropRule extends BaseLowQualityTweetLabelRule(Drop(LowQualityTweet)) + +object LowQualityTweetLabelTombstoneRule + extends BaseLowQualityTweetLabelRule(Tombstone(Epitaph.Unavailable)) + +abstract class SafetyCrisisLevelDropRule(level: Int, condition: Condition = TrueCondition) + extends ConditionWithTweetLabelRule( + Drop(Unspecified), + And( + NonAuthorViewer, + condition, + TweetHasSafetyLabelWithScoreEqInt(TweetSafetyLabelType.SafetyCrisis, level) + ), + TweetSafetyLabelType.SafetyCrisis + ) + +object SafetyCrisisAnyLevelDropRule + extends NonAuthorWithTweetLabelRule( + 
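+      // Level-agnostic variant: matches any SafetyCrisis label, while the
+      // SafetyCrisisLevel{2,3,4}DropRule objects below require the label score to
+      // equal a specific crisis level via TweetHasSafetyLabelWithScoreEqInt.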
Drop(Unspecified), + TweetSafetyLabelType.SafetyCrisis + ) + +object SafetyCrisisLevel2DropRule extends SafetyCrisisLevelDropRule(2, Not(ViewerDoesFollowAuthor)) + +object SafetyCrisisLevel3DropRule extends SafetyCrisisLevelDropRule(3, Not(ViewerDoesFollowAuthor)) + +object SafetyCrisisLevel4DropRule extends SafetyCrisisLevelDropRule(4) + +abstract class SafetyCrisisLevelSectionRule(level: Int) + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + And( + TweetHasLabel(TweetSafetyLabelType.SafetyCrisis), + TweetHasSafetyLabelWithScoreEqInt(TweetSafetyLabelType.SafetyCrisis, level)) + ) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.SafetyCrisis)) +} + +object SafetyCrisisLevel3SectionRule + extends SafetyCrisisLevelSectionRule(3) + with DoesLogVerdictDecidered { + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object SafetyCrisisLevel4SectionRule + extends SafetyCrisisLevelSectionRule(4) + with DoesLogVerdictDecidered { + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object DoNotAmplifyDropRule + extends NonFollowerWithTweetLabelRule(Drop(Unspecified), TweetSafetyLabelType.DoNotAmplify) + +object DoNotAmplifyAllViewersDropRule + extends TweetHasLabelRule(Drop(Unspecified), TweetSafetyLabelType.DoNotAmplify) + +object DoNotAmplifySectionRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + TweetHasLabel(TweetSafetyLabelType.DoNotAmplify)) + +object HighPSpammyScoreAllViewerDropRule + extends TweetHasLabelRule(Drop(Unspecified), TweetSafetyLabelType.HighPSpammyTweetScore) + +object HighPSpammyTweetScoreSearchTweetLabelDropRule + extends RuleWithConstantAction( + action = Drop(Unspecified), + condition = And( + LoggedOutOrViewerNotFollowingAuthor, + TweetHasLabelWithScoreAboveThreshold( + TweetSafetyLabelType.HighPSpammyTweetScore, + ModelScoreThresholds.HighPSpammyTweetScoreThreshold) + ) + ) + with DoesLogVerdictDecidered { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableHighPSpammyTweetScoreSearchTweetLabelDropRuleParam) + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighPSpammyTweetScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging +} + +object AdsManagerDenyListAllUsersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.AdsManagerDenyList + ) + +abstract class SmyteSpamTweetLabelRule(action: Action) + extends NonAuthorWithTweetLabelRule( + action, + TweetSafetyLabelType.SmyteSpamTweet + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableSmyteSpamTweetRuleParam) +} + +object SmyteSpamTweetLabelDropRule extends SmyteSpamTweetLabelRule(Drop(TweetLabeledSpam)) + +object SmyteSpamTweetLabelTombstoneRule + extends SmyteSpamTweetLabelRule(Tombstone(Epitaph.Unavailable)) + +object SmyteSpamTweetLabelDropSearchRule extends SmyteSpamTweetLabelRule(Drop(Unspecified)) + +object HighSpammyTweetContentScoreSearchLatestTweetLabelDropRule + extends RuleWithConstantAction( + action = Drop(Unspecified), + condition = And( + Not(IsTweetInTweetLevelStcmHoldback), + LoggedOutOrViewerNotFollowingAuthor, + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighSpammyTweetContentScore, + 
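+          // The threshold comes from a per-surface feature-switch param; the Top and
+          // Trends variants that follow reuse this same shape, differing only in the
+          // threshold param and (for Trends) an extra IsTrendClickSourceSearchResult
+          // condition.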
HighSpammyTweetContentScoreSearchLatestProdTweetLabelDropRuleThresholdParam) + ) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighSpammyTweetContentScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging +} + +object HighSpammyTweetContentScoreSearchTopTweetLabelDropRule + extends RuleWithConstantAction( + action = Drop(Unspecified), + condition = And( + Not(IsTweetInTweetLevelStcmHoldback), + LoggedOutOrViewerNotFollowingAuthor, + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighSpammyTweetContentScore, + HighSpammyTweetContentScoreSearchTopProdTweetLabelDropRuleThresholdParam) + ) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighSpammyTweetContentScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging + +} + +object HighSpammyTweetContentScoreTrendsTopTweetLabelDropRule + extends RuleWithConstantAction( + action = Drop(Unspecified), + condition = And( + Not(IsTweetInTweetLevelStcmHoldback), + LoggedOutOrViewerNotFollowingAuthor, + IsTrendClickSourceSearchResult, + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighSpammyTweetContentScore, + HighSpammyTweetContentScoreTrendTopTweetLabelDropRuleThresholdParam) + ) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighSpammyTweetContentScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging + +} + +object HighSpammyTweetContentScoreTrendsLatestTweetLabelDropRule + extends RuleWithConstantAction( + action = Drop(Unspecified), + condition = And( + Not(IsTweetInTweetLevelStcmHoldback), + LoggedOutOrViewerNotFollowingAuthor, + IsTrendClickSourceSearchResult, + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighSpammyTweetContentScore, + HighSpammyTweetContentScoreTrendLatestTweetLabelDropRuleThresholdParam) + ) + ) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.HighSpammyTweetContentScore)) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging +} + +object GoreAndViolenceTopicHighRecallTweetLabelRule + extends NonAuthorWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.GoreAndViolenceTopicHighRecall + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableGoreAndViolenceTopicHighRecallTweetLabelRule) +} + +object CopypastaSpamAllViewersTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.CopypastaSpam + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableCopypastaSpamSearchDropRule) +} + +object CopypastaSpamAllViewersSearchTweetLabelRule + extends TweetHasLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.CopypastaSpam + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableCopypastaSpamSearchDropRule) +} + +object CopypastaSpamNonFollowerSearchTweetLabelRule + extends NonFollowerWithTweetLabelRule( + Drop(Unspecified), + TweetSafetyLabelType.CopypastaSpam + ) { + override def enabled: 
Seq[RuleParam[Boolean]] = + Seq(EnableCopypastaSpamSearchDropRule) +} + +object CopypastaSpamAbusiveQualityTweetLabelRule + extends ConditionWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + TweetHasLabel(TweetSafetyLabelType.CopypastaSpam) + ) + with DoesLogVerdictDecidered { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableCopypastaSpamDownrankConvosAbusiveQualityRule) + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.CopypastaSpam)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + +object DynamicProductAdLimitedEngagementTweetLabelRule + extends TweetHasLabelRule( + LimitedEngagements(LimitedEngagementReason.DynamicProductAd), + TweetSafetyLabelType.DynamicProductAd) + +object SkipTweetDetailLimitedEngagementTweetLabelRule + extends AlwaysActRule(LimitedEngagements(LimitedEngagementReason.SkipTweetDetail)) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + SkipTweetDetailLimitedEngagementRuleEnabledParam) +} + +object DynamicProductAdDropTweetLabelRule + extends TweetHasLabelRule(Drop(Unspecified), TweetSafetyLabelType.DynamicProductAd) + +object NsfwTextTweetLabelTopicsDropRule + extends RuleWithConstantAction( + Drop(Reason.Nsfw), + And( + NonAuthorViewer, + Or( + TweetHasLabel(TweetSafetyLabelType.ExperimentalSensitiveIllegal2), + TweetHasLabel(TweetSafetyLabelType.NsfwTextHighPrecision) + ) + ) + ) + with DoesLogVerdict { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNsfwTextTopicsDropRuleParam) + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.NsfwTextHighPrecision)) +} + + +object ExperimentalNudgeLabelRule + extends TweetHasLabelRule( + TweetVisibilityNudge(TweetVisibilityNudgeReason.ExperimentalNudgeSafetyLabelReason), + TweetSafetyLabelType.ExperimentalNudge) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableExperimentalNudgeEnabledParam) +} + +object NsfwTextTweetLabelAvoidRule + extends RuleWithConstantAction( + Avoid(), + Or( + TweetHasLabel(TweetSafetyLabelType.ExperimentalSensitiveIllegal2), + TweetHasLabel(TweetSafetyLabelType.NsfwTextHighPrecision) + ) + ) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(TweetSafetyLabelType.NsfwTextHighPrecision)) +} + +object DoNotAmplifyTweetLabelAvoidRule + extends TweetHasLabelRule( + Avoid(), + TweetSafetyLabelType.DoNotAmplify + ) + +object NsfaHighPrecisionTweetLabelAvoidRule + extends TweetHasLabelRule( + Avoid(), + TweetSafetyLabelType.NsfaHighPrecision + ) { + override val fallbackActionBuilder: Option[ActionBuilder[_ <: Action]] = Some( + new ConstantActionBuilder(Avoid(Some(MightNotBeSuitableForAds)))) +} + +object NsfwHighPrecisionTweetLabelAvoidRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.NsfwHighPrecision + ) { + override val fallbackActionBuilder: Option[ActionBuilder[_ <: Action]] = Some( + new ConstantActionBuilder(Avoid(Some(MightNotBeSuitableForAds)))) +} + +object NsfwHighRecallTweetLabelAvoidRule + extends TweetHasLabelRule( + Avoid(Some(AvoidReason.ContainsNsfwMedia)), + TweetSafetyLabelType.NsfwHighRecall + ) { + override val fallbackActionBuilder: Option[ActionBuilder[_ <: Action]] = Some( + new ConstantActionBuilder(Avoid(Some(MightNotBeSuitableForAds)))) +} diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetRules.scala new file mode 100644 index 000000000..666b2debf --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/TweetRules.scala @@ -0,0 +1,594 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.common.actions.LimitedEngagementReason +import com.twitter.visibility.configapi.params.FSRuleParams.AdAvoidanceHighToxicityModelScoreThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.AdAvoidanceReportedTweetModelScoreThresholdParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetCommunityUnavailableLimitedActionsRulesEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetDropProtectedRuleEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetDropRuleEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetLimitedActionsRulesEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetMemberRemovedLimitedActionsRulesEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.CommunityTweetNonMemberLimitedActionsRuleEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.StaleTweetLimitedActionsRulesEnabledParam +import com.twitter.visibility.configapi.params.FSRuleParams.TrustedFriendsTweetLimitedEngagementsRuleEnabledParam +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.configapi.params.RuleParams._ +import com.twitter.visibility.features.{TweetDeleteReason => FeatureTweetDeleteReason} +import com.twitter.visibility.models.TweetDeleteReason +import com.twitter.visibility.models.TweetSafetyLabelType +import com.twitter.visibility.rules.Condition.ViewerIsExclusiveTweetAuthor +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Reason.CommunityTweetAuthorRemoved +import com.twitter.visibility.rules.Reason.CommunityTweetHidden +import com.twitter.visibility.rules.Reason.Nsfw +import com.twitter.visibility.rules.Reason.StaleTweet +import com.twitter.visibility.rules.Reason.Unspecified +import com.twitter.visibility.rules.RuleActionSourceBuilder.TweetSafetyLabelSourceBuilder + +abstract class TweetHasLabelRule(action: Action, tweetSafetyLabelType: TweetSafetyLabelType) + extends RuleWithConstantAction(action, TweetHasLabel(tweetSafetyLabelType)) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(tweetSafetyLabelType)) +} + +abstract class ConditionWithTweetLabelRule( + action: Action, + condition: Condition, + tweetSafetyLabelType: TweetSafetyLabelType) + extends RuleWithConstantAction(action, And(TweetHasLabel(tweetSafetyLabelType), condition)) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(tweetSafetyLabelType)) +} + +abstract class NonAuthorWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule(action, NonAuthorViewer, tweetSafetyLabelType) { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( + TweetSafetyLabelSourceBuilder(tweetSafetyLabelType)) +} + +abstract class NonFollowerWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + 
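+// "Non-follower" deliberately includes logged-out viewers: the condition below is
+// LoggedOutOrViewerNotFollowingAuthor rather than a bare Not(ViewerDoesFollowAuthor).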
extends ConditionWithTweetLabelRule( + action, + LoggedOutOrViewerNotFollowingAuthor, + tweetSafetyLabelType + ) + +abstract class NonAuthorAndNonFollowerWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + And(NonAuthorViewer, LoggedOutOrViewerNotFollowingAuthor), + tweetSafetyLabelType + ) + +abstract class NonFollowerWithUqfTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + Or( + LoggedOutViewer, + And( + NonAuthorViewer, + Not(ViewerDoesFollowAuthor), + ViewerHasUqfEnabled + ) + ), + tweetSafetyLabelType + ) + +abstract class ViewerWithUqfTweetLabelRule(action: Action, labelValue: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule(action, ViewerHasUqfEnabled, labelValue) + +case object ConversationControlRules { + + abstract class ConversationControlBaseRule(condition: Condition) + extends RuleWithConstantAction( + LimitedEngagements(LimitedEngagementReason.ConversationControl), + condition) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(TweetConversationControlEnabledParam) + } + + object LimitRepliesCommunityConversationRule + extends ConversationControlBaseRule( + And( + TweetIsCommunityConversation, + Not( + Or( + LoggedOutViewer, + Retweet, + ViewerIsTweetConversationRootAuthor, + ViewerIsInvitedToTweetConversation, + ConversationRootAuthorDoesFollowViewer + )) + ) + ) + + object LimitRepliesFollowersConversationRule + extends ConversationControlBaseRule( + And( + TweetIsFollowersConversation, + Not( + Or( + LoggedOutViewer, + Retweet, + ViewerIsTweetConversationRootAuthor, + ViewerIsInvitedToTweetConversation, + ViewerDoesFollowConversationRootAuthor + )) + ) + ) + + object LimitRepliesByInvitationConversationRule + extends ConversationControlBaseRule( + And( + TweetIsByInvitationConversation, + Not( + Or( + LoggedOutViewer, + Retweet, + ViewerIsTweetConversationRootAuthor, + ViewerIsInvitedToTweetConversation + )) + ) + ) + +} + +abstract class NonAuthorViewerOptInFilteringWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + And(NonAuthorViewer, LoggedOutOrViewerOptInFiltering), + tweetSafetyLabelType) + +abstract class NonFollowerViewerOptInFilteringWithTweetLabelRule( + action: Action, + tweetSafetyLabelType: TweetSafetyLabelType) + extends ConditionWithTweetLabelRule( + action, + And(LoggedOutOrViewerNotFollowingAuthor, LoggedOutOrViewerOptInFiltering), + tweetSafetyLabelType + ) + +object TweetNsfwUserDropRule extends RuleWithConstantAction(Drop(Nsfw), TweetHasNsfwUserAuthor) +object TweetNsfwAdminDropRule extends RuleWithConstantAction(Drop(Nsfw), TweetHasNsfwAdminAuthor) + +object NullcastedTweetRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(Nullcast, Not(Retweet), Not(IsQuotedInnerTweet), Not(TweetIsCommunityTweet))) + +object MutedRetweetsRule + extends RuleWithConstantAction(Drop(Unspecified), And(Retweet, ViewerMutesRetweetsFromAuthor)) + +abstract class FilterCommunityTweetsRule(override val action: Action) + extends RuleWithConstantAction(action, TweetIsCommunityTweet) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(CommunityTweetDropRuleEnabledParam) +} + +object DropCommunityTweetsRule extends FilterCommunityTweetsRule(Drop(CommunityTweetHidden)) + +object TombstoneCommunityTweetsRule + extends FilterCommunityTweetsRule(Tombstone(Epitaph.Unavailable)) + +abstract class 
FilterCommunityTweetCommunityNotVisibleRule(override val action: Action) + extends RuleWithConstantAction( + action, + And( + NonAuthorViewer, + TweetIsCommunityTweet, + Not(CommunityTweetCommunityVisible), + )) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(DropCommunityTweetWithUndefinedCommunityRuleEnabledParam) +} + +object DropCommunityTweetCommunityNotVisibleRule + extends FilterCommunityTweetCommunityNotVisibleRule(Drop(CommunityTweetHidden)) + +object TombstoneCommunityTweetCommunityNotVisibleRule + extends FilterCommunityTweetCommunityNotVisibleRule(Tombstone(Epitaph.Unavailable)) + +abstract class FilterAllCommunityTweetsRule(override val action: Action) + extends RuleWithConstantAction(action, TweetIsCommunityTweet) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(CommunityTweetsEnabledParam) +} + +object DropAllCommunityTweetsRule extends FilterAllCommunityTweetsRule(Drop(Unspecified)) + +object TombstoneAllCommunityTweetsRule + extends FilterAllCommunityTweetsRule(Tombstone(Epitaph.Unavailable)) + +object DropOuterCommunityTweetsRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(TweetIsCommunityTweet, Not(IsQuotedInnerTweet))) + +object DropAllHiddenCommunityTweetsRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(TweetIsCommunityTweet, CommunityTweetIsHidden)) + +abstract class FilterHiddenCommunityTweetsRule(override val action: Action) + extends RuleWithConstantAction( + action, + And( + NonAuthorViewer, + TweetIsCommunityTweet, + CommunityTweetIsHidden, + Not(ViewerIsCommunityModerator) + )) + +object DropHiddenCommunityTweetsRule + extends FilterHiddenCommunityTweetsRule(Drop(CommunityTweetHidden)) + +object TombstoneHiddenCommunityTweetsRule + extends FilterHiddenCommunityTweetsRule(Tombstone(Epitaph.CommunityTweetHidden)) + +object DropAllAuthorRemovedCommunityTweetsRule + extends RuleWithConstantAction( + Drop(Unspecified), + And(TweetIsCommunityTweet, CommunityTweetAuthorIsRemoved)) + +abstract class FilterAuthorRemovedCommunityTweetsRule(override val action: Action) + extends RuleWithConstantAction( + action, + And( + NonAuthorViewer, + TweetIsCommunityTweet, + CommunityTweetAuthorIsRemoved, + Not(ViewerIsCommunityModerator) + )) + +object DropAuthorRemovedCommunityTweetsRule + extends FilterAuthorRemovedCommunityTweetsRule(Drop(CommunityTweetAuthorRemoved)) + +object TombstoneAuthorRemovedCommunityTweetsRule + extends FilterAuthorRemovedCommunityTweetsRule(Tombstone(Epitaph.Unavailable)) + +abstract class FilterProtectedCommunityTweetsRule(override val action: Action) + extends RuleWithConstantAction( + action, + And(TweetIsCommunityTweet, ProtectedAuthor, NonAuthorViewer)) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(CommunityTweetDropProtectedRuleEnabledParam) +} + +object DropProtectedCommunityTweetsRule + extends FilterProtectedCommunityTweetsRule(Drop(CommunityTweetHidden)) + +object TombstoneProtectedCommunityTweetsRule + extends FilterProtectedCommunityTweetsRule(Tombstone(Epitaph.Unavailable)) + +abstract class CommunityTweetCommunityUnavailableLimitedActionsRule( + reason: LimitedEngagementReason, + condition: CommunityTweetCommunityUnavailable, +) extends RuleWithConstantAction( + LimitedEngagements(reason), + And( + Not(NonAuthorViewer), + TweetIsCommunityTweet, + condition, + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + CommunityTweetCommunityUnavailableLimitedActionsRulesEnabledParam) +} + +object CommunityTweetCommunityNotFoundLimitedActionsRule + extends 
CommunityTweetCommunityUnavailableLimitedActionsRule( + LimitedEngagementReason.CommunityTweetCommunityNotFound, + CommunityTweetCommunityNotFound, + ) + +object CommunityTweetCommunityDeletedLimitedActionsRule + extends CommunityTweetCommunityUnavailableLimitedActionsRule( + LimitedEngagementReason.CommunityTweetCommunityDeleted, + CommunityTweetCommunityDeleted, + ) + +object CommunityTweetCommunitySuspendedLimitedActionsRule + extends CommunityTweetCommunityUnavailableLimitedActionsRule( + LimitedEngagementReason.CommunityTweetCommunitySuspended, + CommunityTweetCommunitySuspended, + ) + +abstract class CommunityTweetModeratedLimitedActionsRule( + reason: LimitedEngagementReason, + condition: CommunityTweetIsModerated, + enabledParam: RuleParam[Boolean], +) extends RuleWithConstantAction( + LimitedEngagements(reason), + And( + TweetIsCommunityTweet, + condition, + Or( + Not(NonAuthorViewer), + ViewerIsCommunityModerator, + ) + )) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(enabledParam) +} + +object CommunityTweetMemberRemovedLimitedActionsRule + extends CommunityTweetModeratedLimitedActionsRule( + LimitedEngagementReason.CommunityTweetMemberRemoved, + CommunityTweetAuthorIsRemoved, + CommunityTweetMemberRemovedLimitedActionsRulesEnabledParam, + ) + +object CommunityTweetHiddenLimitedActionsRule + extends CommunityTweetModeratedLimitedActionsRule( + LimitedEngagementReason.CommunityTweetHidden, + CommunityTweetIsHidden, + CommunityTweetLimitedActionsRulesEnabledParam, + ) + +abstract class CommunityTweetLimitedActionsRule( + reason: LimitedEngagementReason, + condition: Condition, +) extends RuleWithConstantAction( + LimitedEngagements(reason), + And( + TweetIsCommunityTweet, + condition + )) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(CommunityTweetLimitedActionsRulesEnabledParam) +} + +object CommunityTweetMemberLimitedActionsRule + extends CommunityTweetLimitedActionsRule( + LimitedEngagementReason.CommunityTweetMember, + ViewerIsCommunityMember, + ) + +object CommunityTweetNonMemberLimitedActionsRule + extends CommunityTweetLimitedActionsRule( + LimitedEngagementReason.CommunityTweetNonMember, + Not(ViewerIsCommunityMember), + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + CommunityTweetNonMemberLimitedActionsRuleEnabledParam) +} + +object ReportedTweetInterstitialRule + extends RuleWithConstantAction( + Interstitial(Reason.ViewerReportedTweet), + And( + NonAuthorViewer, + Not(Retweet), + ViewerReportsTweet + )) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableReportedTweetInterstitialRule) +} + +object ReportedTweetInterstitialSearchRule + extends RuleWithConstantAction( + Interstitial(Reason.ViewerReportedTweet), + And( + NonAuthorViewer, + Not(Retweet), + ViewerReportsTweet + )) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableReportedTweetInterstitialSearchRule) +} + +abstract class FilterExclusiveTweetContentRule( + action: Action, + additionalCondition: Condition = Condition.True) + extends RuleWithConstantAction( + action, + And( + additionalCondition, + TweetIsExclusiveContent, + Or( + LoggedOutViewer, + Not( + Or( + ViewerIsExclusiveTweetAuthor, + ViewerSuperFollowsExclusiveTweetAuthor, + And( + Not(NonAuthorViewer), + Not(Retweet) + ) + ) + ), + ), + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDropExclusiveTweetContentRule) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq( + EnableDropExclusiveTweetContentRuleFailClosed) +} + +object DropExclusiveTweetContentRule + extends 
FilterExclusiveTweetContentRule(Drop(Reason.ExclusiveTweet)) + +object TombstoneExclusiveTweetContentRule + extends FilterExclusiveTweetContentRule(Tombstone(Epitaph.SuperFollowsContent)) + +object TombstoneExclusiveQuotedTweetContentRule + extends FilterExclusiveTweetContentRule( + Tombstone(Epitaph.SuperFollowsContent), + IsQuotedInnerTweet + ) + +object DropAllExclusiveTweetsRule + extends RuleWithConstantAction( + Drop(Reason.ExclusiveTweet), + TweetIsExclusiveContent + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDropAllExclusiveTweetsRuleParam) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq( + EnableDropAllExclusiveTweetsRuleFailClosedParam) +} + +object DropTweetsWithGeoRestrictedMediaRule + extends RuleWithConstantAction(Drop(Unspecified), MediaRestrictedInViewerCountry) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + EnableDropTweetsWithGeoRestrictedMediaRuleParam) +} + +object TrustedFriendsTweetLimitedEngagementsRule + extends RuleWithConstantAction( + LimitedEngagements(LimitedEngagementReason.TrustedFriendsTweet), + TweetIsTrustedFriendsContent + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + TrustedFriendsTweetLimitedEngagementsRuleEnabledParam + ) +} + +object DropAllTrustedFriendsTweetsRule + extends RuleWithConstantAction( + Drop(Reason.TrustedFriendsTweet), + TweetIsTrustedFriendsContent + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDropAllTrustedFriendsTweetsRuleParam) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) +} + +object DropAllCollabInvitationTweetsRule + extends RuleWithConstantAction( + Drop(Unspecified), + TweetIsCollabInvitationContent + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDropAllCollabInvitationTweetsRuleParam) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) +} + +abstract class FilterTrustedFriendsTweetContentRule(action: Action) + extends OnlyWhenNotAuthorViewerRule( + action, + And( + TweetIsTrustedFriendsContent, + Not( + Or( + ViewerIsTrustedFriendsTweetAuthor, + ViewerIsTrustedFriend + ) + ) + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDropTrustedFriendsTweetContentRuleParam) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq(RuleParams.True) +} + +object DropTrustedFriendsTweetContentRule + extends FilterTrustedFriendsTweetContentRule(Drop(Reason.TrustedFriendsTweet)) + +object TombstoneTrustedFriendsTweetContentRule + extends FilterTrustedFriendsTweetContentRule(Tombstone(Epitaph.Unavailable)) + +object TweetNsfwUserAdminAvoidRule + extends RuleWithConstantAction( + Avoid(), + Or( + TweetHasNsfwUserAuthor, + TweetHasNsfwAdminAuthor, + NsfwUserAuthor, + NsfwAdminAuthor + ) + ) + +object AvoidHighToxicityModelScoreRule + extends RuleWithConstantAction( + Avoid(), + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighToxicityScore, + AdAvoidanceHighToxicityModelScoreThresholdParam) + ) + +object AvoidReportedTweetModelScoreRule + extends RuleWithConstantAction( + Avoid(), + TweetHasLabelWithScoreAboveThresholdWithParam( + TweetSafetyLabelType.HighPReportedTweetScore, + AdAvoidanceReportedTweetModelScoreThresholdParam) + ) + +object TombstoneDeletedOuterTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.Deleted), + And( + Equals(FeatureTweetDeleteReason, TweetDeleteReason.Deleted), + Not(IsQuotedInnerTweet) + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + 
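+// The delete-state rules in this block all follow one pattern: match the tweet's
+// TweetDeleteReason feature, optionally restrict by render position (outer tweet vs.
+// quoted inner tweet), and tombstone with the matching Epitaph. As a sketch of how a
+// caller-side policy might wire the position-specific variants together (illustrative
+// only -- ExampleDeletedTweetPolicy is hypothetical and not part of this change;
+// VisibilityPolicy and its deletedTweetRules parameter are defined in
+// VisibilityPolicy.scala later in this diff):
+//
+//   case object ExampleDeletedTweetPolicy
+//     extends VisibilityPolicy(
+//       deletedTweetRules = Seq(
+//         TombstoneDeletedOuterTweetRule,       // top-level tweets: Epitaph.Deleted
+//         TombstoneDeletedQuotedTweetRule,      // quoted inner tweets: Epitaph.Deleted
+//         TombstoneBounceDeletedOuterTweetRule, // bounce-deleted, top-level
+//         TombstoneBounceDeletedQuotedTweetRule // bounce-deleted, quoted inner
+//       )
+//     )
+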
+object TombstoneDeletedTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.Deleted), + And( + Equals(FeatureTweetDeleteReason, TweetDeleteReason.Deleted), + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + +object TombstoneDeletedQuotedTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.Deleted), + And( + Equals(FeatureTweetDeleteReason, TweetDeleteReason.Deleted), + IsQuotedInnerTweet + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + +object TombstoneBounceDeletedTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.BounceDeleted), + Equals(FeatureTweetDeleteReason, TweetDeleteReason.BounceDeleted), + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + +object TombstoneBounceDeletedOuterTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.BounceDeleted), + And( + Equals(FeatureTweetDeleteReason, TweetDeleteReason.BounceDeleted), + Not(IsQuotedInnerTweet) + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + +object TombstoneBounceDeletedQuotedTweetRule + extends RuleWithConstantAction( + Tombstone(Epitaph.BounceDeleted), + And( + Equals(FeatureTweetDeleteReason, TweetDeleteReason.BounceDeleted), + IsQuotedInnerTweet + ) + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableDeleteStateTweetRulesParam) +} + + +object DropStaleTweetsRule + extends RuleWithConstantAction( + Drop(StaleTweet), + And(TweetIsStaleTweet, Not(IsQuotedInnerTweet), Not(Retweet), Not(IsSourceTweet))) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableStaleTweetDropRuleParam) + override def enableFailClosed: Seq[RuleParam[Boolean]] = Seq( + EnableStaleTweetDropRuleFailClosedParam) +} + +object StaleTweetLimitedActionsRule + extends RuleWithConstantAction( + LimitedEngagements(LimitedEngagementReason.StaleTweet), + TweetIsStaleTweet) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(StaleTweetLimitedActionsRulesEnabledParam) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserLabelRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserLabelRules.scala new file mode 100644 index 000000000..668b28538 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserLabelRules.scala @@ -0,0 +1,361 @@ +package com.twitter.visibility.rules + +import com.twitter.abdecider.LoggingABDecider +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.configapi.configs.DeciderKey +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.configapi.params.RuleParams._ +import com.twitter.visibility.features.Feature +import com.twitter.visibility.models.UserLabelValue +import com.twitter.visibility.models.UserLabelValue._ +import com.twitter.visibility.rules.Condition._ +import com.twitter.visibility.rules.Reason._ +import com.twitter.visibility.rules.RuleActionSourceBuilder.UserSafetyLabelSourceBuilder + +object AbusiveRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + Abusive + ) + +object DoNotAmplifyUserRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + DoNotAmplify + ) + +object AbusiveHighRecallRule + extends AuthorLabelAndNonFollowerViewerRule( + Drop(Unspecified), + AbusiveHighRecall + ) + +object CompromisedRule + extends 
WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + Compromised + ) + +object DuplicateContentRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + DuplicateContent + ) + +object EngagementSpammerRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + EngagementSpammer + ) + +object EngagementSpammerHighRecallRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + EngagementSpammerHighRecall + ) + +object LiveLowQualityRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + LiveLowQuality + ) + +object LowQualityRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + LowQuality + ) + +object LowQualityHighRecallRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + LowQualityHighRecall + ) + +object NotGraduatedRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NotGraduated + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNotGraduatedDropRuleParam) + override def holdbacks: Seq[RuleParam[Boolean]] = Seq( + NotGraduatedUserLabelRuleHoldbackExperimentParam) + +} + +abstract class BaseNsfwHighPrecisionRule() + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + UserLabelValue.NsfwHighPrecision + ) +object NsfwHighPrecisionRule + extends BaseNsfwHighPrecisionRule() + +object NsfwHighRecallRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwHighRecall + ) + +abstract class BaseNsfwNearPerfectAuthorRule() + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwNearPerfect + ) +object NsfwNearPerfectAuthorRule extends BaseNsfwNearPerfectAuthorRule() + +object NsfwAvatarImageRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwAvatarImage + ) + +object NsfwBannerImageRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwBannerImage + ) + +object NsfwSensitiveRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwSensitive + ) + +object ReadOnlyRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + ReadOnly + ) + +object RecommendationsBlacklistRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + RecommendationsBlacklist + ) + +sealed abstract class BaseSpamHighRecallRule(val holdback: RuleParam[Boolean]) + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + SpamHighRecall + ) { + override val holdbacks: Seq[RuleParam[Boolean]] = Seq(holdback) +} + +object SpamHighRecallRule extends BaseSpamHighRecallRule(RuleParams.False) + +object DeciderableSpamHighRecallRule extends BaseSpamHighRecallRule(RuleParams.False) + +object SearchBlacklistRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + SearchBlacklist + ) + +object SearchNsfwTextRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwText + ) { + + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNsfwTextSectioningRuleParam) +} + +object SpammyFollowerRule + extends OnlyWhenNotAuthorViewerRule( + Drop(Unspecified), + And( + Or( + AuthorHasLabel(Compromised), + AuthorHasLabel(EngagementSpammer), + AuthorHasLabel(EngagementSpammerHighRecall), + AuthorHasLabel(LowQuality), + AuthorHasLabel(ReadOnly), + AuthorHasLabel(SpamHighRecall) + ), + Or( + LoggedOutViewer, + And( + NonAuthorViewer, + ViewerHasUqfEnabled, + Or( + And( + ProtectedViewer, + LoggedOutOrViewerNotFollowingAuthor, + Not(AuthorDoesFollowViewer) + ), + And(Not(ProtectedViewer), LoggedOutOrViewerNotFollowingAuthor) + ) + ) + ) + ) + ) + +abstract class 
NonFollowerWithUqfUserLabelDropRule(labelValue: UserLabelValue) + extends ConditionWithUserLabelRule( + Drop(Unspecified), + Or( + LoggedOutViewer, + And(Not(ViewerDoesFollowAuthor), ViewerHasUqfEnabled) + ), + labelValue + ) + +object EngagementSpammerNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + EngagementSpammer + ) + +object EngagementSpammerHighRecallNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + EngagementSpammerHighRecall + ) + +object SpamHighRecallNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + SpamHighRecall + ) + +object CompromisedNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + Compromised + ) + +object ReadOnlyNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + ReadOnly + ) + +object LowQualityNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule( + LowQuality + ) + +object TsViolationRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + TsViolation + ) + +object DownrankSpamReplyAllViewersRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + DownrankSpamReply + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) + } + +object DownrankSpamReplyNonAuthorRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + DownrankSpamReply + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) + } + +object DownrankSpamReplyNonFollowerWithUqfRule + extends NonFollowerWithUqfUserLabelDropRule(DownrankSpamReply) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableDownrankSpamReplySectioningRuleParam) + } + +object NsfwTextAllUsersDropRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwText + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNsfwTextSectioningRuleParam) + } + +object NsfwTextNonAuthorDropRule + extends WhenAuthorUserLabelPresentRule( + Drop(Unspecified), + NsfwText + ) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableNsfwTextSectioningRuleParam) + } + +abstract class DeciderableSpamHighRecallAuthorLabelRule(action: Action) + extends RuleWithConstantAction( + action, + And( + NonAuthorViewer, + SelfReply, + AuthorHasLabel(SpamHighRecall, shortCircuitable = false) + ) + ) { + override def preFilter( + evaluationContext: EvaluationContext, + featureMap: Map[Feature[_], Any], + abDecider: LoggingABDecider + ): PreFilterResult = { + Filtered + } +} + +object DeciderableSpamHighRecallAuthorLabelDropRule + extends DeciderableSpamHighRecallAuthorLabelRule(Drop(Unspecified)) + +object DeciderableSpamHighRecallAuthorLabelTombstoneRule + extends DeciderableSpamHighRecallAuthorLabelRule(Tombstone(Epitaph.Unavailable)) + +object DoNotAmplifyNonFollowerRule + extends AuthorLabelAndNonFollowerViewerRule( + Drop(Unspecified), + DoNotAmplify + ) + +object NotGraduatedNonFollowerRule + extends AuthorLabelAndNonFollowerViewerRule( + Drop(Unspecified), + NotGraduated + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq(EnableNotGraduatedDropRuleParam) + override def holdbacks: Seq[RuleParam[Boolean]] = Seq( + NotGraduatedUserLabelRuleHoldbackExperimentParam) + +} + +object DoNotAmplifySectionUserRule + extends AuthorLabelWithNotInnerCircleOfFriendsRule( + ConversationSectionAbusiveQuality, + DoNotAmplify) + with DoesLogVerdictDecidered { + override def actionSourceBuilder: Option[RuleActionSourceBuilder] = Some( +
UserSafetyLabelSourceBuilder(DoNotAmplify)) + override def verdictLogDeciderKey = DeciderKey.EnableDownlevelRuleVerdictLogging +} + + +object SpammyUserModelHighPrecisionDropTweetRule + extends AuthorLabelAndNonFollowerViewerRule( + Drop(Unspecified), + SpammyUserModelHighPrecision, + ) + with DoesLogVerdictDecidered { + override def isEnabled(params: Params): Boolean = + params(EnableSpammyUserModelTweetDropRuleParam) + override def verdictLogDeciderKey: DeciderKey.Value = + DeciderKey.EnableSpammyTweetRuleVerdictLogging +} + +object LikelyIvsLabelNonFollowerDropUserRule extends LikelyIvsLabelNonFollowerDropRule + +object SearchLikelyIvsLabelNonFollowerDropUserRule extends LikelyIvsLabelNonFollowerDropRule + +object NsfwHighPrecisionUserLabelAvoidTweetRule + extends UserHasLabelRule( + Avoid(), + UserLabelValue.NsfwHighPrecision + ) { + override def enabled: Seq[RuleParam[Boolean]] = Seq( + NsfwHighPrecisionUserLabelAvoidTweetRuleEnabledParam) +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserUnavailableStateTombstoneRules.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserUnavailableStateTombstoneRules.scala new file mode 100644 index 000000000..716ea6ab8 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/UserUnavailableStateTombstoneRules.scala @@ -0,0 +1,120 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam +import com.twitter.visibility.configapi.params.RuleParams.EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam +import com.twitter.visibility.rules.Condition.And +import com.twitter.visibility.rules.Condition.AuthorBlocksViewer +import com.twitter.visibility.rules.Condition.DeactivatedAuthor +import com.twitter.visibility.rules.Condition.ErasedAuthor +import com.twitter.visibility.rules.Condition.IsQuotedInnerTweet +import com.twitter.visibility.rules.Condition.OffboardedAuthor +import com.twitter.visibility.rules.Condition.ProtectedAuthor +import com.twitter.visibility.rules.Condition.Retweet +import com.twitter.visibility.rules.Condition.SuspendedAuthor +import com.twitter.visibility.rules.Condition.UnavailableAuthor +import com.twitter.visibility.rules.Condition.ViewerBlocksAuthor +import com.twitter.visibility.rules.Condition.ViewerMutesAuthor + +object UserUnavailableStateTombstoneRules { + abstract class UserUnavailableStateTweetTombstoneRule(epitaph: Epitaph, condition: Condition) + extends RuleWithConstantAction(Tombstone(epitaph), condition) {} + + abstract class UserUnavailableStateRetweetTombstoneRule(epitaph: Epitaph, condition: Condition) + extends RuleWithConstantAction(Tombstone(epitaph), And(Retweet, condition)) {} + + abstract class UserUnavailableStateInnerQuotedTweetTombstoneRule( + epitaph: Epitaph, + condition: Condition) + extends RuleWithConstantAction(Tombstone(epitaph), And(IsQuotedInnerTweet, condition)) + + abstract class UserUnavailableStateInnerQuotedTweetInterstitialRule( + reason: Reason, + condition: Condition) + extends RuleWithConstantAction(Interstitial(reason), And(IsQuotedInnerTweet, condition)) + + object SuspendedUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Suspended, SuspendedAuthor) + + object DeactivatedUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Deactivated, DeactivatedAuthor) + + object 
OffBoardedUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Offboarded, OffboardedAuthor) + + object ErasedUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Deactivated, ErasedAuthor) + + object ProtectedUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Protected, ProtectedAuthor) + + object AuthorBlocksViewerUserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.BlockedBy, AuthorBlocksViewer) + + object UserUnavailableTweetTombstoneRule + extends UserUnavailableStateTweetTombstoneRule(Epitaph.Unavailable, UnavailableAuthor) + + object SuspendedUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Suspended, SuspendedAuthor) + + object DeactivatedUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Deactivated, DeactivatedAuthor) + + object OffBoardedUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Offboarded, OffboardedAuthor) + + object ErasedUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Deactivated, ErasedAuthor) + + object ProtectedUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Protected, ProtectedAuthor) + + object AuthorBlocksViewerUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.BlockedBy, AuthorBlocksViewer) + + object ViewerBlocksAuthorUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Unavailable, ViewerBlocksAuthor) + + object ViewerMutesAuthorUserUnavailableRetweetTombstoneRule + extends UserUnavailableStateRetweetTombstoneRule(Epitaph.Unavailable, ViewerMutesAuthor) + + object SuspendedUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule(Epitaph.Suspended, SuspendedAuthor) + + object DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule( + Epitaph.Deactivated, + DeactivatedAuthor) + + object OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule( + Epitaph.Offboarded, + OffboardedAuthor) + + object ErasedUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule(Epitaph.Deactivated, ErasedAuthor) + + object ProtectedUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule(Epitaph.Protected, ProtectedAuthor) + + object AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule + extends UserUnavailableStateInnerQuotedTweetTombstoneRule( + Epitaph.BlockedBy, + AuthorBlocksViewer) + + object ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule + extends UserUnavailableStateInnerQuotedTweetInterstitialRule( + Reason.ViewerBlocksAuthor, + ViewerBlocksAuthor) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableInnerQuotedTweetViewerBlocksAuthorInterstitialRuleParam) + } + + object ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + extends UserUnavailableStateInnerQuotedTweetInterstitialRule( + Reason.ViewerMutesAuthor, + ViewerMutesAuthor) { + override def enabled: Seq[RuleParam[Boolean]] = + Seq(EnableInnerQuotedTweetViewerMutesAuthorInterstitialRuleParam) + } +} diff --git 
a/visibilitylib/src/main/scala/com/twitter/visibility/rules/VisibilityPolicy.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/VisibilityPolicy.scala new file mode 100644 index 000000000..1ff0eaada --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/VisibilityPolicy.scala @@ -0,0 +1,3778 @@ +package com.twitter.visibility.rules + +import com.twitter.visibility.configapi.params.RuleParam +import com.twitter.visibility.configapi.params.RuleParams +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.rules.ConversationControlRules._ +import com.twitter.visibility.rules.FollowerRelations.AuthorMutesViewerRule +import com.twitter.visibility.rules.FollowerRelations.ProtectedViewerRule +import com.twitter.visibility.rules.PolicyLevelRuleParams.ruleParams +import com.twitter.visibility.rules.PublicInterestRules._ +import com.twitter.visibility.rules.SafeSearchTweetRules._ +import com.twitter.visibility.rules.SafeSearchUserRules.SafeSearchNsfwAvatarImageUserLabelRule +import com.twitter.visibility.rules.SafeSearchUserRules._ +import com.twitter.visibility.rules.SpaceRules._ +import com.twitter.visibility.rules.ToxicityReplyFilterRules.ToxicityReplyFilterDropNotificationRule +import com.twitter.visibility.rules.ToxicityReplyFilterRules.ToxicityReplyFilterRule +import com.twitter.visibility.rules.UnsafeSearchTweetRules._ +import com.twitter.visibility.rules.UserUnavailableStateTombstoneRules._ + +abstract class VisibilityPolicy( + val tweetRules: Seq[Rule] = Nil, + val userRules: Seq[Rule] = Nil, + val cardRules: Seq[Rule] = Nil, + val quotedTweetRules: Seq[Rule] = Nil, + val dmRules: Seq[Rule] = Nil, + val dmConversationRules: Seq[Rule] = Nil, + val dmEventRules: Seq[Rule] = Nil, + val spaceRules: Seq[Rule] = Nil, + val userUnavailableStateRules: Seq[Rule] = Nil, + val twitterArticleRules: Seq[Rule] = Nil, + val deletedTweetRules: Seq[Rule] = Nil, + val mediaRules: Seq[Rule] = Nil, + val communityRules: Seq[Rule] = Nil, + val policyRuleParams: Map[Rule, PolicyLevelRuleParams] = Map.empty) { + + def forContentId(contentId: ContentId): Seq[Rule] = + contentId match { + case ContentId.TweetId(_) => tweetRules + case ContentId.UserId(_) => userRules + case ContentId.CardId(_) => cardRules + case ContentId.QuotedTweetRelationship(_, _) => quotedTweetRules + case ContentId.NotificationId(_) => userRules + case ContentId.DmId(_) => dmRules + case ContentId.BlenderTweetId(_) => userRules ++ tweetRules + case ContentId.SpaceId(_) => spaceRules + case ContentId.SpacePlusUserId(_) => spaceRules ++ userRules + case ContentId.DmConversationId(_) => dmConversationRules + case ContentId.DmEventId(_) => dmEventRules + case ContentId.UserUnavailableState(_) => userUnavailableStateRules + case ContentId.TwitterArticleId(_) => twitterArticleRules + case ContentId.DeleteTweetId(_) => deletedTweetRules + case ContentId.MediaId(_) => mediaRules + case ContentId.CommunityId(_) => communityRules + } + + private[visibility] def allRules: Seq[Rule] = + (tweetRules ++ userRules ++ cardRules ++ quotedTweetRules ++ dmRules ++ spaceRules ++ dmConversationRules ++ dmEventRules ++ twitterArticleRules ++ deletedTweetRules ++ mediaRules ++ communityRules) +} + +object VisibilityPolicy { + val baseTweetRules = Seq( + DropCommunityTweetsRule, + DropCommunityTweetCommunityNotVisibleRule, + DropProtectedCommunityTweetsRule, + DropHiddenCommunityTweetsRule, + DropAuthorRemovedCommunityTweetsRule, + SpamTweetLabelRule, + PdnaTweetLabelRule, + BounceTweetLabelRule, + 
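+    // (When policies concatenate rule sequences such as this baseline,
+    // VisibilityPolicy.union below deduplicates them, so a shared rule is
+    // evaluated at most once per policy.)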
DropExclusiveTweetContentRule, + DropTrustedFriendsTweetContentRule + ) + + val baseTweetTombstoneRules = Seq( + TombstoneCommunityTweetsRule, + TombstoneCommunityTweetCommunityNotVisibleRule, + TombstoneProtectedCommunityTweetsRule, + TombstoneHiddenCommunityTweetsRule, + TombstoneAuthorRemovedCommunityTweetsRule, + SpamTweetLabelTombstoneRule, + PdnaTweetLabelTombstoneRule, + BounceTweetLabelTombstoneRule, + TombstoneExclusiveTweetContentRule, + TombstoneTrustedFriendsTweetContentRule, + ) + + val baseMediaRules = Seq( + ) + + val baseQuotedTweetTombstoneRules = Seq( + BounceQuotedTweetTombstoneRule + ) + + def union[T](rules: Seq[Rule]*): Seq[Rule] = { + if (rules.isEmpty) { + Seq.empty[Rule] + } else { + rules.reduce((a, b) => a ++ b.filterNot(a.contains)) + } + } +} + +case class PolicyLevelRuleParams( + ruleParams: Seq[RuleParam[Boolean]], + force: Boolean = false) {} + +object PolicyLevelRuleParams { + def ruleParams(ruleParams: RuleParam[Boolean]*): PolicyLevelRuleParams = { + PolicyLevelRuleParams(ruleParams) + } + + def ruleParams(force: Boolean, ruleParams: RuleParam[Boolean]*): PolicyLevelRuleParams = { + PolicyLevelRuleParams(ruleParams, force) + } +} + +case object FilterAllPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropAllRule), + userRules = Seq(DropAllRule), + cardRules = Seq(DropAllRule), + quotedTweetRules = Seq(DropAllRule), + dmRules = Seq(DropAllRule), + dmConversationRules = Seq(DropAllRule), + dmEventRules = Seq(DropAllRule), + spaceRules = Seq(DropAllRule), + userUnavailableStateRules = Seq(DropAllRule), + twitterArticleRules = Seq(DropAllRule), + deletedTweetRules = Seq(DropAllRule), + mediaRules = Seq(DropAllRule), + communityRules = Seq(DropAllRule), + ) + +case object FilterNonePolicy extends VisibilityPolicy() + +object ConversationsAdAvoidanceRules { + val tweetRules = Seq( + NsfwHighRecallTweetLabelAvoidRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwTextTweetLabelAvoidRule, + AvoidHighToxicityModelScoreRule, + AvoidReportedTweetModelScoreRule, + NsfwHighPrecisionUserLabelAvoidTweetRule, + TweetNsfwUserAdminAvoidRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) + + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighRecallTweetLabelAvoidRule -> ruleParams( + RuleParams.EnableNewAdAvoidanceRulesParam + ), + NsfwHighPrecisionTweetLabelAvoidRule -> ruleParams( + RuleParams.EnableNewAdAvoidanceRulesParam + ), + NsfwTextTweetLabelAvoidRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + AvoidHighToxicityModelScoreRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + AvoidReportedTweetModelScoreRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + NsfwHighPrecisionUserLabelAvoidTweetRule -> ruleParams( + RuleParams.EnableNewAdAvoidanceRulesParam), + TweetNsfwUserAdminAvoidRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + DoNotAmplifyTweetLabelAvoidRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + NsfaHighPrecisionTweetLabelAvoidRule -> ruleParams(RuleParams.EnableNewAdAvoidanceRulesParam), + ) +} + +case object FilterDefaultPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule + ) + ) + +case object LimitedEngagementBaseRules + extends 
VisibilityPolicy( + tweetRules = Seq( + StaleTweetLimitedActionsRule, + LimitRepliesByInvitationConversationRule, + LimitRepliesCommunityConversationRule, + LimitRepliesFollowersConversationRule, + CommunityTweetCommunityNotFoundLimitedActionsRule, + CommunityTweetCommunityDeletedLimitedActionsRule, + CommunityTweetCommunitySuspendedLimitedActionsRule, + CommunityTweetMemberRemovedLimitedActionsRule, + CommunityTweetHiddenLimitedActionsRule, + CommunityTweetMemberLimitedActionsRule, + CommunityTweetNonMemberLimitedActionsRule, + DynamicProductAdLimitedEngagementTweetLabelRule, + TrustedFriendsTweetLimitedEngagementsRule + ) + ) + +case object WritePathLimitedActionsEnforcementPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule + ) ++ + LimitedEngagementBaseRules.tweetRules + ) + +case object TestPolicy + extends VisibilityPolicy( + tweetRules = Seq( + TestRule + ) + ) + +case object CardsServicePolicy + extends VisibilityPolicy( + cardRules = Seq( + DropProtectedAuthorPollCardRule, + DropCardUriRootDomainDenylistRule + ), + spaceRules = Seq( + SpaceHighToxicityScoreNonFollowerDropRule, + SpaceHatefulHighRecallAllUsersDropRule, + SpaceViolenceHighRecallAllUsersDropRule, + ViewerIsSoftUserDropRule + ), + ) + +case object CardPollVotingPolicy + extends VisibilityPolicy( + cardRules = Seq( + DropProtectedAuthorPollCardRule, + DropCommunityNonMemberPollCardRule + ) + ) + +case object UserTimelineRules { + val UserRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ) +} + +case object TimelineLikedByRules { + val UserRules = Seq( + CompromisedNonFollowerWithUqfRule, + EngagementSpammerNonFollowerWithUqfRule, + LowQualityNonFollowerWithUqfRule, + ReadOnlyNonFollowerWithUqfRule, + SpamHighRecallNonFollowerWithUqfRule + ) +} + +case object FollowingAndFollowersUserListPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object FriendsFollowingListPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object ListOwnershipsPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object ListRecommendationsPolicy + extends VisibilityPolicy( + userRules = RecommendationsPolicy.userRules ++ Seq( + DropNsfwUserAuthorRule, + NsfwHighRecallRule, + SearchBlacklistRule, + SearchNsfwTextRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule + ) + ) + +case object ListSearchBaseRules { + + val NonExperimentalSafeSearchMinimalPolicyUserRules: Seq[Rule] = + SafeSearchMinimalPolicy.userRules.filterNot(_.isExperimental) + + val MinimalPolicyUserRules: Seq[Rule] = NonExperimentalSafeSearchMinimalPolicyUserRules + + val BlockMutePolicyUserRules = Seq( + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule + ) + + val StrictPolicyUserRules = Seq( + SafeSearchAbusiveUserLabelRule, + SafeSearchAbusiveHighRecallUserLabelRule, + SafeSearchCompromisedUserLabelRule, + SafeSearchDoNotAmplifyNonFollowersUserLabelRule, + SafeSearchDuplicateContentUserLabelRule, + SafeSearchLowQualityUserLabelRule, + SafeSearchNotGraduatedNonFollowersUserLabelRule, + SafeSearchNsfwHighPrecisionUserLabelRule, + SafeSearchNsfwAvatarImageUserLabelRule, + SafeSearchNsfwBannerImageUserLabelRule, + SafeSearchReadOnlyUserLabelRule, + SafeSearchSearchBlacklistUserLabelRule, + SafeSearchNsfwTextUserLabelRule, + SafeSearchSpamHighRecallUserLabelRule, + 
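// Tiered layout: MinimalPolicyUserRules is the floor, BlockMutePolicyUserRules adds the viewer block/mute opt-in rules, + // and StrictPolicyUserRules (this list) adds the safe-search label rules; ListSearchPolicy below concatenates all three tiers: + //   userRules = MinimalPolicyUserRules ++ BlockMutePolicyUserRules ++ StrictPolicyUserRules +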
SafeSearchDownrankSpamReplyAuthorLabelRule, + SafeSearchNsfwTextAuthorLabelRule, + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + DropNsfwUserAuthorViewerOptInFilteringOnSearchRule, + ) +} + +object SensitiveMediaSettingsTimelineHomeBaseRules { + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam), + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam), + NsfwReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam), + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam), + NsfwCardImageAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsHomeTimelineRulesParam) + ) +} + +object SensitiveMediaSettingsConversationBaseRules { + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaConversationRulesParam), + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaConversationRulesParam), + NsfwReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaConversationRulesParam), + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaConversationRulesParam), + NsfwCardImageAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaConversationRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule -> 
ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsConversationRulesParam) + ) +} + +object SensitiveMediaSettingsProfileTimelineBaseRules { + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam), + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam), + NsfwReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam), + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam), + NsfwCardImageAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsProfileTimelineRulesParam) + ) +} + +object SensitiveMediaSettingsTweetDetailBaseRules { + val policyRuleParams = Map[Rule, PolicyLevelRuleParams]( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam), + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam), + NsfwReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam), + 
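// Reading of these maps (the gating semantics live in RuleParam, outside this file): each entry ties a rule to a + // Boolean param, so a surface can presumably flip between the legacy label rules and the new sensitive-media + // interstitial rules without editing the rule lists themselves, e.g.: + //   SomeRule -> ruleParams(RuleParams.SomeEnableParam) // hypothetical entry; rule gated by the param +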
GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam), + NsfwCardImageAllUsersTweetLabelRule -> ruleParams( + RuleParams.EnableLegacySensitiveMediaTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam), + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule -> ruleParams( + RuleParams.EnableNewSensitiveMediaSettingsInterstitialsTweetDetailRulesParam) + ) +} + +case object ListSearchPolicy + extends VisibilityPolicy( + userRules = ListSearchBaseRules.MinimalPolicyUserRules ++ + ListSearchBaseRules.BlockMutePolicyUserRules ++ + ListSearchBaseRules.StrictPolicyUserRules + ) + +case object ListSubscriptionsPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object ListMembershipsPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object AllSubscribedListsPolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object ListHeaderPolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ) + ) + +case object NewUserExperiencePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfaHighRecallTweetLabelRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + GoreAndViolenceTweetLabelRule, + UntrustedUrlTweetLabelRule, + DownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + AutomationTweetLabelRule, + DuplicateMentionTweetLabelRule, + BystanderAbusiveTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + SmyteSpamTweetLabelDropRule, + ), + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + SearchBlacklistRule, + SearchNsfwTextRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + AbusiveHighRecallRule, + DoNotAmplifyNonFollowerRule, + 
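// Sketch of the two ruleParams factories from PolicyLevelRuleParams earlier (hypothetical params P1, P2): + //   ruleParams(P1, P2)  // => PolicyLevelRuleParams(Seq(P1, P2), force = false) + //   ruleParams(true, P1) // force = true => PolicyLevelRuleParams(Seq(P1), force = true) +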
NotGraduatedNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + DownrankSpamReplyNonAuthorRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object DESHomeTimelinePolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + DropAllCommunityTweetsRule + ) ++ + VisibilityPolicy.baseTweetRules, + userRules = UserTimelineRules.UserRules + ) + +case object DesQuoteTweetTimelinePolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule + ) ++ ElevatedQuoteTweetTimelinePolicy.tweetRules.diff(Seq(DropStaleTweetsRule)), + userRules = Seq( + ProtectedAuthorDropRule + ), + policyRuleParams = ElevatedQuoteTweetTimelinePolicy.policyRuleParams + ) + +case object DESRealtimeSpamEnrichmentPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + LowQualityTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + SearchBlacklistTweetLabelRule, + SmyteSpamTweetLabelDropRule, + DropAllCommunityTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule + ) + ) + +case object DESRealtimePolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropAllCommunityTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + DropAllCollabInvitationTweetsRule + ), + userRules = Seq( + DropAllProtectedAuthorRule, + DropProtectedViewerIfPresentRule + ) + ) + +case object DESRetweetingUsersPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ), + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ) + ) + +case object DESTweetLikingUsersPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ), + userRules = TimelineLikedByRules.UserRules + ) + +case object DESUserBookmarksPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ) ++ + (VisibilityPolicy.baseTweetRules + ++ Seq(DropAllCommunityTweetsRule) + ++ TimelineProfileRules.tweetRules), + userRules = UserTimelineRules.UserRules + ) + +case object DESUserLikedTweetsPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ) ++ + ( + VisibilityPolicy.baseTweetRules ++ + Seq( + DropAllCommunityTweetsRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + 
NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ), + userRules = UserTimelineRules.UserRules + ) + +case object DESUserMentionsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + DropAllCommunityTweetsRule, + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + SuspendedAuthorRule + ) + ) + +case object DESUserTweetsPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ) ++ + (VisibilityPolicy.baseTweetRules + ++ Seq(DropAllCommunityTweetsRule) + ++ TimelineProfileRules.tweetRules), + userRules = UserTimelineRules.UserRules + ) + +case object DevPlatformComplianceStreamPolicy + extends VisibilityPolicy( + tweetRules = Seq( + SpamAllUsersTweetLabelRule, + PdnaAllUsersTweetLabelRule, + BounceAllUsersTweetLabelRule, + AbusePolicyEpisodicTweetLabelComplianceTweetNoticeRule, + ) + ) + +case object DesTweetDetailPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ) ++ BaseTweetDetailPolicy.tweetRules + ) + +case object DevPlatformGetListTweetsPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropStaleTweetsRule) ++ DesTweetDetailPolicy.tweetRules + ) + +case object FollowerConnectionsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + SpammyFollowerRule + ) + ) + +case object SuperFollowerConnectionsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + SpammyFollowerRule + ) + ) + +case object LivePipelineEngagementCountsPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object LiveVideoTimelinePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + AbusiveHighRecallTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + LiveLowQualityTweetLabelRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + SearchBlacklistTweetLabelRule, + BystanderAbusiveTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + ), + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + SearchBlacklistRule, + SearchNsfwTextRule, + CompromisedRule, + NsfwHighPrecisionRule, + NsfwHighRecallRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + SpamHighRecallRule, + 
DuplicateContentRule, + LiveLowQualityRule, + EngagementSpammerRule, + EngagementSpammerHighRecallRule, + AbusiveHighRecallRule, + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object MagicRecsPolicyOverrides { + val replacements: Map[Rule, Rule] = Map() + def union(rules: Seq[Rule]*): Seq[Rule] = rules + .map(ar => ar.map(x => replacements.getOrElse(x, x))) + .reduce((a, b) => a ++ b.filterNot(a.contains)) +} + +case object MagicRecsPolicy + extends VisibilityPolicy( + tweetRules = MagicRecsPolicyOverrides.union( + RecommendationsPolicy.tweetRules.filterNot(_ == SafetyCrisisLevel3DropRule), + NotificationsIbisPolicy.tweetRules, + Seq(NsfaHighRecallTweetLabelRule, NsfwHighRecallTweetLabelRule), + Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule + ), + Seq( + DeactivatedAuthorRule, + SuspendedAuthorRule, + TweetNsfwUserDropRule, + TweetNsfwAdminDropRule + ) + ), + userRules = MagicRecsPolicyOverrides.union( + RecommendationsPolicy.userRules, + NotificationsRules.userRules + ) + ) + +case object MagicRecsV2Policy + extends VisibilityPolicy( + tweetRules = MagicRecsPolicyOverrides.union( + MagicRecsPolicy.tweetRules, + NotificationsWriterTweetHydratorPolicy.tweetRules + ), + userRules = MagicRecsPolicyOverrides.union( + MagicRecsPolicy.userRules, + NotificationsWriterV2Policy.userRules + ) + ) + +case object MagicRecsAggressivePolicy + extends VisibilityPolicy( + tweetRules = MagicRecsPolicy.tweetRules, + userRules = MagicRecsPolicy.userRules + ) + +case object MagicRecsAggressiveV2Policy + extends VisibilityPolicy( + tweetRules = MagicRecsV2Policy.tweetRules, + userRules = MagicRecsV2Policy.userRules + ) + +case object MinimalPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + TsViolationRule + ) + ) + +case object ModeratedTweetsTimelinePolicy + extends VisibilityPolicy( + tweetRules = TweetDetailPolicy.tweetRules.diff( + Seq( + AuthorBlocksViewerDropRule, + MutedKeywordForTweetRepliesInterstitialRule, + ReportedTweetInterstitialRule)), + policyRuleParams = TweetDetailPolicy.policyRuleParams + ) + +case object MomentsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AuthorBlocksViewerUnspecifiedRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object NearbyTimelinePolicy + extends VisibilityPolicy( + tweetRules = SearchBlenderRules.tweetRelevanceRules, + userRules = SearchBlenderRules.userBaseRules + ) + +private object NotificationsRules { + val tweetRules: Seq[Rule] = + DropStaleTweetsRule +: VisibilityPolicy.baseTweetRules + + val userRules: Seq[Rule] = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + AbusiveHighRecallRule, + EngagementSpammerNonFollowerWithUqfRule, + EngagementSpammerHighRecallNonFollowerWithUqfRule, + DownrankSpamReplyNonFollowerWithUqfRule + ) +} + +case object NotificationsIbisPolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveUqfNonFollowerTweetLabelRule, + 
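// Sketch of MagicRecsPolicyOverrides.union defined above (hypothetical rules A, B, C and a non-empty + // replacements map Map(B -> C)): each sequence is first rewritten through replacements, then concatenated + // with duplicates dropped in first-occurrence order: + //   union(Seq(A, B), Seq(A, C)) // => Seq(A, C) + // VisibilityPolicy.union earlier in the file behaves the same way, minus the rewriting step. +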
LowQualityTweetLabelDropRule, + ToxicityReplyFilterDropNotificationRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + DuplicateMentionTweetLabelRule, + LowQualityMentionTweetLabelRule, + UntrustedUrlUqfNonFollowerTweetLabelRule, + DownrankSpamReplyUqfNonFollowerTweetLabelRule, + SafetyCrisisAnyLevelDropRule, + DoNotAmplifyDropRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + ), + userRules = NotificationsRules.userRules ++ Seq( + DoNotAmplifyNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object NotificationsReadPolicy + extends VisibilityPolicy( + tweetRules = NotificationsRules.tweetRules, + userRules = NotificationsRules.userRules + ) + +case object NotificationsTimelineDeviceFollowPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + CompromisedRule + ) + ) + +case object NotificationsWritePolicy + extends VisibilityPolicy( + tweetRules = NotificationsRules.tweetRules, + userRules = NotificationsRules.userRules + ) + +case object NotificationsWriterV2Policy + extends VisibilityPolicy( + userRules = + Seq( + AuthorBlocksViewerDropRule, + DeactivatedAuthorRule, + ErasedAuthorRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule, + DeactivatedViewerRule, + SuspendedViewerRule, + ViewerBlocksAuthorRule, + ViewerMutesAndDoesNotFollowAuthorRule, + ViewerIsUnmentionedRule, + NoConfirmedEmailRule, + NoConfirmedPhoneRule, + NoDefaultProfileImageRule, + NoNewUsersRule, + NoNotFollowedByRule, + OnlyPeopleIFollowRule + ) ++ + NotificationsRules.userRules + ) + +case object NotificationsWriterTweetHydratorPolicy + extends VisibilityPolicy( + tweetRules = NotificationsRules.tweetRules ++ + Seq( + LowQualityTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + DuplicateMentionUqfTweetLabelRule, + LowQualityMentionTweetLabelRule, + SmyteSpamTweetLabelDropRule, + ToxicityReplyFilterDropNotificationRule, + AbusiveUqfNonFollowerTweetLabelRule, + UntrustedUrlUqfNonFollowerTweetLabelRule, + DownrankSpamReplyUqfNonFollowerTweetLabelRule, + ViewerHasMatchingMutedKeywordForNotificationsRule, + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object NotificationsPlatformPolicy + extends VisibilityPolicy( + tweetRules = NotificationsWriterTweetHydratorPolicy.tweetRules, + userRules = NotificationsWriterV2Policy.userRules + ) + +case object NotificationsPlatformPushPolicy + extends VisibilityPolicy( + tweetRules = NotificationsIbisPolicy.tweetRules, + userRules = Seq(ViewerMutesAuthorRule) + ++ NotificationsIbisPolicy.userRules + ) + +case object QuoteTweetTimelinePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + DropStaleTweetsRule, + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + 
NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + GoreAndViolenceTweetLabelRule, + UntrustedUrlTweetLabelRule, + DownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + AutomationTweetLabelRule, + DuplicateMentionTweetLabelRule, + BystanderAbusiveTweetLabelRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + SearchBlacklistRule, + SearchNsfwTextRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + AbusiveHighRecallRule, + DownrankSpamReplyNonAuthorRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object ElevatedQuoteTweetTimelinePolicy + extends VisibilityPolicy( + tweetRules = + TweetDetailPolicy.tweetRules.diff( + Seq( + MutedKeywordForQuotedTweetTweetDetailInterstitialRule, + ReportedTweetInterstitialRule)), + policyRuleParams = TweetDetailPolicy.policyRuleParams + ) + +case object EmbedsPublicInterestNoticePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) + ) + +case object RecommendationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + GoreAndViolenceTweetLabelRule, + BystanderAbusiveTweetLabelRule, + DoNotAmplifyDropRule, + SafetyCrisisLevel3DropRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + ), + userRules = Seq( + DropNsfwAdminAuthorRule, + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + RecommendationsBlacklistRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwNearPerfectAuthorRule, + NsfwBannerImageRule, + NsfwAvatarImageRule, + EngagementSpammerRule, + EngagementSpammerHighRecallRule, + AbusiveHighRecallRule, + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object RecosVideoPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + BystanderAbusiveTweetLabelRule, + SmyteSpamTweetLabelDropRule, + ), + userRules = Seq(NsfwTextNonAuthorDropRule) + ) + +case object RepliesGroupingPolicy + extends VisibilityPolicy( + tweetRules = 
VisibilityPolicy.baseTweetRules ++ + Seq( + LowQualityTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + DeciderableSpamHighRecallAuthorLabelDropRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + MutedKeywordForTweetRepliesInterstitialRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + LowQualityRule, + ReadOnlyRule, + LowQualityHighRecallRule, + CompromisedRule, + DeciderableSpamHighRecallRule + ) + ) + +case object ReturningUserExperiencePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfaHighRecallTweetLabelRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + NsfwTextTweetLabelTopicsDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + GoreAndViolenceTweetLabelRule, + UntrustedUrlTweetLabelRule, + DownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + AutomationTweetLabelRule, + DuplicateMentionTweetLabelRule, + BystanderAbusiveTweetLabelRule, + SmyteSpamTweetLabelDropRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + SearchBlacklistRule, + SearchNsfwTextRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + AbusiveHighRecallRule, + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + DownrankSpamReplyNonAuthorRule, + NsfwTextNonAuthorDropRule, + DropNsfwUserAuthorRule, + NsfwHighRecallRule + ) + ) + +case object ReturningUserExperienceFocalTweetPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AuthorBlocksViewerDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + MutedKeywordForTweetRepliesInterstitialRule, + ViewerMutesAuthorInterstitialRule, + ReportedTweetInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object RevenuePolicy + extends VisibilityPolicy( + tweetRules = 
VisibilityPolicy.baseTweetRules ++ + Seq( + AbusiveTweetLabelRule, + BystanderAbusiveTweetLabelRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule + ) + ) + +case object SafeSearchMinimalPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropOuterCommunityTweetsRule, + ) ++ VisibilityPolicy.baseTweetRules ++ Seq( + LowQualityTweetLabelDropRule, + HighProactiveTosScoreTweetLabelDropSearchRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + SearchBlacklistTweetLabelRule, + SearchBlacklistHighRecallTweetLabelDropRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + SmyteSpamTweetLabelDropRule, + ) ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ++ SearchBlenderRules.tweetAvoidRules, + userRules = Seq( + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + SpamHighRecallRule, + SearchBlacklistRule, + SearchNsfwTextRule, + DuplicateContentRule, + DoNotAmplifyNonFollowerRule, + SearchLikelyIvsLabelNonFollowerDropUserRule + ) + ) + +case object SearchHydrationPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialSearchRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object SearchBlenderRules { + val limitedEngagementBaseRules: Seq[Rule] = LimitedEngagementBaseRules.tweetRules + + val tweetAvoidRules: Seq[Rule] = + Seq( + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + SearchAvoidTweetNsfwAdminRule, + SearchAvoidTweetNsfwUserRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) + + val basicBlockMuteRules: Seq[Rule] = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule + ) + + val tweetRelevanceRules: Seq[Rule] = + Seq( + DropOuterCommunityTweetsRule, + DropStaleTweetsRule, + ) ++ VisibilityPolicy.baseTweetRules ++ Seq( + SafeSearchAbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + HighProactiveTosScoreTweetLabelDropSearchRule, + HighPSpammyTweetScoreSearchTweetLabelDropRule, + HighSpammyTweetContentScoreSearchTopTweetLabelDropRule, + HighSpammyTweetContentScoreTrendsTopTweetLabelDropRule, + SafeSearchNsfwHighPrecisionTweetLabelRule, + SafeSearchGoreAndViolenceHighPrecisionTweetLabelRule, + SafeSearchNsfwReportedHeuristicsTweetLabelRule, + SafeSearchGoreAndViolenceReportedHeuristicsTweetLabelRule, + SafeSearchNsfwCardImageTweetLabelRule, + 
SafeSearchNsfwHighRecallTweetLabelRule, + SafeSearchNsfwVideoTweetLabelRule, + SafeSearchNsfwTextTweetLabelRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + SafeSearchGoreAndViolenceTweetLabelRule, + SafeSearchUntrustedUrlTweetLabelRule, + SafeSearchDownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + SearchBlacklistHighRecallTweetLabelDropRule, + SmyteSpamTweetLabelDropSearchRule, + CopypastaSpamAllViewersSearchTweetLabelRule, + ) ++ basicBlockMuteRules ++ + Seq( + SafeSearchAutomationNonFollowerTweetLabelRule, + SafeSearchDuplicateMentionNonFollowerTweetLabelRule, + SafeSearchBystanderAbusiveTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + SearchIpiSafeSearchWithoutUserInQueryDropRule, + SearchEdiSafeSearchWithoutUserInQueryDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + UnsafeSearchNsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + UnsafeSearchNsfwCardImageAllUsersTweetLabelRule, + ) ++ + limitedEngagementBaseRules ++ + tweetAvoidRules + + val tweetRecencyRules: Seq[Rule] = + VisibilityPolicy.baseTweetRules ++ Seq( + SafeSearchAbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + HighProactiveTosScoreTweetLabelDropSearchRule, + HighSpammyTweetContentScoreSearchLatestTweetLabelDropRule, + HighSpammyTweetContentScoreTrendsLatestTweetLabelDropRule, + SafeSearchNsfwHighPrecisionTweetLabelRule, + SafeSearchGoreAndViolenceHighPrecisionTweetLabelRule, + SafeSearchNsfwReportedHeuristicsTweetLabelRule, + SafeSearchGoreAndViolenceReportedHeuristicsTweetLabelRule, + SafeSearchNsfwCardImageTweetLabelRule, + SafeSearchNsfwHighRecallTweetLabelRule, + SafeSearchNsfwVideoTweetLabelRule, + SafeSearchNsfwTextTweetLabelRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + SafeSearchGoreAndViolenceTweetLabelRule, + SafeSearchUntrustedUrlTweetLabelRule, + SafeSearchDownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + SearchBlacklistHighRecallTweetLabelDropRule, + SmyteSpamTweetLabelDropSearchRule, + CopypastaSpamNonFollowerSearchTweetLabelRule, + ) ++ + basicBlockMuteRules ++ + Seq( + SafeSearchAutomationNonFollowerTweetLabelRule, + SafeSearchDuplicateMentionNonFollowerTweetLabelRule, + SafeSearchBystanderAbusiveTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + SearchIpiSafeSearchWithoutUserInQueryDropRule, + SearchEdiSafeSearchWithoutUserInQueryDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + UnsafeSearchNsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + UnsafeSearchNsfwCardImageAllUsersTweetLabelRule, + ) ++ limitedEngagementBaseRules ++ tweetAvoidRules + + val userBaseRules: Seq[ConditionWithUserLabelRule] = Seq( + SafeSearchAbusiveUserLabelRule, + LowQualityRule, + ReadOnlyRule, + SearchBlacklistRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + DoNotAmplifyNonFollowerRule, + SearchLikelyIvsLabelNonFollowerDropUserRule, + SafeSearchNsfwHighPrecisionUserLabelRule, + SafeSearchNsfwAvatarImageUserLabelRule, + SafeSearchNsfwBannerImageUserLabelRule, +
SafeSearchAbusiveHighRecallUserLabelRule, + SafeSearchDownrankSpamReplyAuthorLabelRule, + SafeSearchNotGraduatedNonFollowersUserLabelRule, + SafeSearchNsfwTextAuthorLabelRule + ) + + val userRules: Seq[ConditionWithUserLabelRule] = userBaseRules + + val userRelevanceBaseRules = userBaseRules ++ basicBlockMuteRules + + val userRelevanceRules = userRelevanceBaseRules + + val userRecencyBaseRules = userBaseRules.filterNot( + Seq(DoNotAmplifyNonFollowerRule, SearchLikelyIvsLabelNonFollowerDropUserRule).contains + ) ++ basicBlockMuteRules + + val searchQueryMatchesTweetAuthorRules: Seq[ConditionWithUserLabelRule] = + userBaseRules + + val basicBlockMutePolicyRuleParam: Map[Rule, PolicyLevelRuleParams] = + SearchBlenderRules.basicBlockMuteRules + .map(rule => rule -> ruleParams(RuleParams.EnableSearchBasicBlockMuteRulesParam)).toMap +} + +case object SearchBlenderUserRulesPolicy + extends VisibilityPolicy( + userRules = SearchBlenderRules.userRules + ) + +case object SearchLatestUserRulesPolicy + extends VisibilityPolicy( + userRules = SearchLatestPolicy.userRules + ) + +case object UserSearchSrpPolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule, + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + SafeSearchAbusiveUserLabelRule, + SafeSearchHighRecallUserLabelRule, + SafeSearchNsfwNearPerfectAuthorRule, + SafeSearchNsfwHighPrecisionUserLabelRule, + SafeSearchNsfwAvatarImageUserLabelRule, + SafeSearchNsfwBannerImageUserLabelRule, + SafeSearchAbusiveHighRecallUserLabelRule, + SafeSearchNsfwTextAuthorLabelRule + ) + ) + +case object UserSearchTypeaheadPolicy + extends VisibilityPolicy( + userRules = Seq( + SafeSearchAbusiveUserLabelRule, + SafeSearchHighRecallUserLabelRule, + SafeSearchNsfwNearPerfectAuthorRule, + SafeSearchNsfwHighPrecisionUserLabelRule, + SafeSearchNsfwAvatarImageUserLabelRule, + SafeSearchNsfwBannerImageUserLabelRule, + SafeSearchAbusiveHighRecallUserLabelRule, + SafeSearchNsfwTextAuthorLabelRule + ), + tweetRules = Seq(DropAllRule) + ) + +case object SearchMixerSrpMinimalPolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule + ) + ) + +case object SearchMixerSrpStrictPolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule, + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + NsfwNearPerfectAuthorRule, + NsfwHighPrecisionRule, + NsfwHighRecallRule, + NsfwSensitiveRule, + NsfwAvatarImageRule, + NsfwBannerImageRule + ) ++ SearchBlenderRules.searchQueryMatchesTweetAuthorRules + .diff(Seq(SafeSearchNotGraduatedNonFollowersUserLabelRule)) + ) + +case object SearchPeopleSrpPolicy + extends VisibilityPolicy( + userRules = SearchBlenderRules.searchQueryMatchesTweetAuthorRules + ) + +case object SearchPeopleTypeaheadPolicy + extends VisibilityPolicy( + userRules = SearchBlenderRules.searchQueryMatchesTweetAuthorRules + .diff( + Seq( + SafeSearchNotGraduatedNonFollowersUserLabelRule + )), + tweetRules = Seq(DropAllRule) + ) + +case object SearchPhotoPolicy + extends VisibilityPolicy( + tweetRules = SearchBlenderRules.tweetRelevanceRules, + userRules = SearchBlenderRules.userRelevanceRules, + policyRuleParams = SearchBlenderRules.basicBlockMutePolicyRuleParam + ) + +case object 
SearchTrendTakeoverPromotedTweetPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object SearchVideoPolicy + extends VisibilityPolicy( + tweetRules = SearchBlenderRules.tweetRelevanceRules, + userRules = SearchBlenderRules.userRelevanceRules, + policyRuleParams = SearchBlenderRules.basicBlockMutePolicyRuleParam + ) + +case object SearchLatestPolicy + extends VisibilityPolicy( + tweetRules = SearchBlenderRules.tweetRecencyRules, + userRules = SearchBlenderRules.userRecencyBaseRules, + policyRuleParams = SearchBlenderRules.basicBlockMutePolicyRuleParam + ) + +case object SearchTopPolicy + extends VisibilityPolicy( + tweetRules = SearchBlenderRules.tweetRelevanceRules, + userRules = Seq(SpammyUserModelHighPrecisionDropTweetRule) ++ + SearchBlenderRules.basicBlockMuteRules ++ + SearchBlenderRules.searchQueryMatchesTweetAuthorRules, + policyRuleParams = SearchBlenderRules.basicBlockMutePolicyRuleParam + ) + +case object SearchTopQigPolicy + extends VisibilityPolicy( + tweetRules = BaseQigPolicy.tweetRules ++ + Seq( + UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelDropRule, + UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelDropRule, + UnsafeSearchNsfwCardImageAllUsersTweetLabelDropRule, + UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelDropRule, + UnsafeSearchNsfwHighPrecisionAllUsersTweetLabelDropRule + ) ++ + SearchTopPolicy.tweetRules.diff( + Seq( + SearchIpiSafeSearchWithoutUserInQueryDropRule, + SearchEdiSafeSearchWithoutUserInQueryDropRule, + HighSpammyTweetContentScoreTrendsTopTweetLabelDropRule, + UnsafeSearchNsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + UnsafeSearchGoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + UnsafeSearchNsfwCardImageAllUsersTweetLabelRule, + UnsafeSearchNsfwReportedHeuristicsAllUsersTweetLabelRule + ) ++ + SearchTopPolicy.tweetRules.intersect(BaseQigPolicy.tweetRules)), + userRules = BaseQigPolicy.userRules ++ Seq( + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + NsfwNearPerfectAuthorRule, + ) ++ SearchTopPolicy.userRules.diff( + SearchTopPolicy.userRules.intersect(BaseQigPolicy.userRules)), + policyRuleParams = SearchBlenderRules.basicBlockMutePolicyRuleParam + ) + +case object SafeSearchStrictPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropOuterCommunityTweetsRule, + ) ++ VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + HighProactiveTosScoreTweetLabelDropSearchRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + GoreAndViolenceTweetLabelRule, + UntrustedUrlTweetLabelRule, + DownrankSpamReplyTweetLabelRule, + SearchBlacklistTweetLabelRule, + SearchBlacklistHighRecallTweetLabelDropRule, + AutomationTweetLabelRule, + DuplicateMentionTweetLabelRule, + BystanderAbusiveTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + SmyteSpamTweetLabelDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ++ SearchBlenderRules.tweetAvoidRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + 
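// Set algebra used by SearchTopQigPolicy above (hypothetical rules A, B, C): + //   Seq(A, B).diff(Seq(B))         // => Seq(A): drop the excluded rules + //   Seq(A, B).intersect(Seq(B, C)) // => Seq(B): the overlap that would otherwise run twice + // so appending top.diff(excluded ++ top.intersect(base)) to the QIG base adds only the rules the + // base does not already contain (a reading of the composition, not a documented contract). +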
ReadOnlyRule, + SearchBlacklistRule, + SearchNsfwTextRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + AbusiveHighRecallRule, + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + SearchLikelyIvsLabelNonFollowerDropUserRule, + DownrankSpamReplyNonAuthorRule, + NsfwTextNonAuthorDropRule, + ) + ) + +case object StickersTimelinePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + SearchBlacklistRule, + SearchNsfwTextRule, + DuplicateContentRule, + EngagementSpammerRule, + EngagementSpammerHighRecallRule, + NsfwSensitiveRule, + SpamHighRecallRule, + AbusiveHighRecallRule + ) + ) + +case object StratoExtLimitedEngagementsPolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseTweetRules ++ LimitedEngagementBaseRules.tweetRules + ) + +case object InternalPromotedContentPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object StreamServicesPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + BystanderAbusiveTweetLabelRule, + SmyteSpamTweetLabelDropRule + ), + userRules = Seq(NsfwTextNonAuthorDropRule) + ) + +case object SuperLikePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule + ), + userRules = Seq(NsfwTextNonAuthorDropRule) + ) + +case object TimelineFocalTweetPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TimelineBookmarkPolicy + extends VisibilityPolicy( + tweetRules = + Seq( + DropCommunityTweetsRule, + DropCommunityTweetCommunityNotVisibleRule, + DropProtectedCommunityTweetsRule, + DropHiddenCommunityTweetsRule, + DropAuthorRemovedCommunityTweetsRule, + SpamTweetLabelRule, + PdnaTweetLabelRule, + BounceOuterTweetTombstoneRule, + BounceQuotedTweetTombstoneRule, + DropExclusiveTweetContentRule, + DropTrustedFriendsTweetContentRule, + ) ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + ViewerBlocksAuthorInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorInnerQuotedTweetInterstitialRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + deletedTweetRules = Seq( + TombstoneBounceDeletedTweetRule, + TombstoneDeletedQuotedTweetRule + ), + userUnavailableStateRules = Seq( + 
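// Selected via forContentId(ContentId.UserUnavailableState(...)); the deletedTweetRules above are reached + // through ContentId.DeleteTweetId(...) the same way, e.g. (hypothetical call): + //   TimelineBookmarkPolicy.forContentId(ContentId.DeleteTweetId(id)) // => deletedTweetRules +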
SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + UserUnavailableTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + ) + +case object TimelineListsPolicy + extends VisibilityPolicy( + tweetRules = + Seq( + DropOuterCommunityTweetsRule, + DropStaleTweetsRule, + ) ++ + VisibilityPolicy.baseTweetRules ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TimelineFavoritesPolicy + extends VisibilityPolicy( + tweetRules = + Seq( + DropOuterCommunityTweetsRule, + DropStaleTweetsRule, + ) + ++ TimelineProfileRules.baseTweetRules + ++ Seq( + DynamicProductAdDropTweetLabelRule, + NsfwHighPrecisionTombstoneInnerQuotedTweetLabelRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + ViewerMutesAuthorInterstitialRule, + ViewerBlocksAuthorInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + 
GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + deletedTweetRules = Seq( + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableInnerQuotedTweetTombstoneRule, + DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule, + OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule, + ErasedUserUnavailableInnerQuotedTweetTombstoneRule, + ProtectedUserUnavailableInnerQuotedTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = SensitiveMediaSettingsProfileTimelineBaseRules.policyRuleParams + ) + +case object ProfileMixerFavoritesPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropExclusiveTweetContentRule, + DropOuterCommunityTweetsRule, + ), + deletedTweetRules = Seq( + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule + ) + ) + +case object TimelineMediaPolicy + extends VisibilityPolicy( + TimelineProfileRules.baseTweetRules + ++ Seq( + NsfwHighPrecisionTombstoneInnerQuotedTweetLabelRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + ViewerMutesAuthorInnerQuotedTweetInterstitialRule, + ViewerBlocksAuthorInnerQuotedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + 
NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableInnerQuotedTweetTombstoneRule, + DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule, + OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule, + ErasedUserUnavailableInnerQuotedTweetTombstoneRule, + ProtectedUserUnavailableInnerQuotedTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = SensitiveMediaSettingsProfileTimelineBaseRules.policyRuleParams + ) + +case object ProfileMixerMediaPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropStaleTweetsRule, + DropExclusiveTweetContentRule + ), + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule + ) + ) + +object TimelineProfileRules { + + val baseTweetRules: Seq[Rule] = Seq( + TombstoneCommunityTweetsRule, + TombstoneCommunityTweetCommunityNotVisibleRule, + TombstoneProtectedCommunityTweetsRule, + TombstoneHiddenCommunityTweetsRule, + TombstoneAuthorRemovedCommunityTweetsRule, + SpamQuotedTweetLabelTombstoneRule, + SpamTweetLabelRule, + PdnaQuotedTweetLabelTombstoneRule, + PdnaTweetLabelRule, + BounceTweetLabelTombstoneRule, + TombstoneExclusiveQuotedTweetContentRule, + DropExclusiveTweetContentRule, + DropTrustedFriendsTweetContentRule + ) + + val tweetRules: Seq[Rule] = + Seq( + DynamicProductAdDropTweetLabelRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + NsfwTextTweetLabelAvoidRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules + + val tweetTombstoneRules: Seq[Rule] = + Seq( + DynamicProductAdDropTweetLabelRule, + NsfwHighPrecisionInnerQuotedTweetLabelRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + 
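+          // The DropSettingLevelTombstone rules in this run appear to tombstone
+          // sensitive-media tweets according to the viewer's per-category media
+          // settings (adult / violent / other), rather than dropping them for
+          // every viewer unconditionally.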
SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + ViewerMutesAuthorInnerQuotedTweetInterstitialRule, + ViewerBlocksAuthorInnerQuotedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules +} + +case object TimelineProfilePolicy + extends VisibilityPolicy( + tweetRules = + Seq( + DropOuterCommunityTweetsRule, + DropStaleTweetsRule, + ) + ++ TimelineProfileRules.baseTweetRules + ++ TimelineProfileRules.tweetTombstoneRules, + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableInnerQuotedTweetTombstoneRule, + DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule, + OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule, + ErasedUserUnavailableInnerQuotedTweetTombstoneRule, + ProtectedUserUnavailableInnerQuotedTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = SensitiveMediaSettingsProfileTimelineBaseRules.policyRuleParams + ) + +case object TimelineProfileAllPolicy + extends VisibilityPolicy( + TimelineProfileRules.baseTweetRules + ++ TimelineProfileRules.tweetTombstoneRules, + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableInnerQuotedTweetTombstoneRule, + DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule, + OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule, + ErasedUserUnavailableInnerQuotedTweetTombstoneRule, + ProtectedUserUnavailableInnerQuotedTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + 
ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = SensitiveMediaSettingsProfileTimelineBaseRules.policyRuleParams + ) + +case object TimelineProfileSuperFollowsPolicy + extends VisibilityPolicy( + tweetRules = + Seq( + DropOuterCommunityTweetsRule + ) ++ + VisibilityPolicy.baseTweetRules ++ + TimelineProfileRules.tweetRules + ) + +case object TimelineReactiveBlendingPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TimelineHomePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseQuotedTweetTombstoneRules ++ + VisibilityPolicy.baseTweetRules ++ + Seq( + NullcastedTweetRule, + DropOuterCommunityTweetsRule, + DynamicProductAdDropTweetLabelRule, + MutedRetweetsRule, + DropAllAuthorRemovedCommunityTweetsRule, + DropAllHiddenCommunityTweetsRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + SafetyCrisisLevel4DropRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwHighPrecisionTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceHighPrecisionDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwCardImageTweetLabelDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwUserTweetFlagDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwAdminTweetFlagDropRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) + ++ + LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule, + DeciderableAuthorBlocksViewerDropRule + ), + 
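+        // policyRuleParams presumably supplies per-rule parameters (here the
+        // sensitive-media settings for the home timeline) that the rules above
+        // read at evaluation time.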
policyRuleParams = SensitiveMediaSettingsTimelineHomeBaseRules.policyRuleParams + ) + +case object BaseTimelineHomePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseQuotedTweetTombstoneRules ++ + VisibilityPolicy.baseTweetRules ++ + Seq( + NullcastedTweetRule, + DropOuterCommunityTweetsRule, + DynamicProductAdDropTweetLabelRule, + MutedRetweetsRule, + DropAllAuthorRemovedCommunityTweetsRule, + DropAllHiddenCommunityTweetsRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + SafetyCrisisLevel4DropRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) + ++ + LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule, + DeciderableAuthorBlocksViewerDropRule + ) + ) + +case object TimelineHomeHydrationPolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseQuotedTweetTombstoneRules ++ + VisibilityPolicy.baseTweetRules ++ + Seq( + SensitiveMediaTweetDropRules.AdultMediaNsfwHighPrecisionTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceHighPrecisionDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwCardImageTweetLabelDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwUserTweetFlagDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwAdminTweetFlagDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfaHighPrecisionTweetLabelAvoidRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + policyRuleParams = SensitiveMediaSettingsTimelineHomeBaseRules.policyRuleParams + ) + +case object TimelineHomeLatestPolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseQuotedTweetTombstoneRules ++ + 
VisibilityPolicy.baseTweetRules ++ + Seq( + NullcastedTweetRule, + DropOuterCommunityTweetsRule, + DynamicProductAdDropTweetLabelRule, + MutedRetweetsRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwHighPrecisionTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceHighPrecisionDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropRule, + SensitiveMediaTweetDropRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropRule, + SensitiveMediaTweetDropRules.AdultMediaNsfwCardImageTweetLabelDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwUserTweetFlagDropRule, + SensitiveMediaTweetDropRules.OtherSensitiveMediaNsfwAdminTweetFlagDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAllUsersTweetLabelRule, + NsfwCardImageAvoidAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + ) + ++ + LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule, + DeciderableAuthorBlocksViewerDropRule + ), + policyRuleParams = SensitiveMediaSettingsTimelineHomeBaseRules.policyRuleParams + ) + +case object TimelineModeratedTweetsHydrationPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object SignalsReactionsPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AuthorBlocksViewerDropRule + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object SignalsTweetReactingUsersPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules :+ + NsfwVideoTweetLabelDropRule :+ + NsfwTextAllUsersTweetLabelDropRule, + userRules = Seq( + CompromisedNonFollowerWithUqfRule, + EngagementSpammerNonFollowerWithUqfRule, + LowQualityNonFollowerWithUqfRule, + ReadOnlyNonFollowerWithUqfRule, + SpamHighRecallNonFollowerWithUqfRule, + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object 
SocialProofPolicy + extends VisibilityPolicy( + tweetRules = FilterDefaultPolicy.tweetRules, + userRules = Seq( + ProtectedAuthorDropRule, + SuspendedAuthorRule, + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule + ) + ) + +case object TimelineLikedByPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules :+ + NsfwVideoTweetLabelDropRule :+ + NsfwTextAllUsersTweetLabelDropRule, + userRules = TimelineLikedByRules.UserRules :+ NsfwTextNonAuthorDropRule + ) + +case object TimelineRetweetedByPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules :+ + NsfwVideoTweetLabelDropRule :+ + NsfwTextAllUsersTweetLabelDropRule, + userRules = Seq( + CompromisedNonFollowerWithUqfRule, + EngagementSpammerNonFollowerWithUqfRule, + LowQualityNonFollowerWithUqfRule, + ReadOnlyNonFollowerWithUqfRule, + SpamHighRecallNonFollowerWithUqfRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object TimelineSuperLikedByPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules :+ + NsfwVideoTweetLabelDropRule :+ + NsfwTextAllUsersTweetLabelDropRule, + userRules = Seq( + CompromisedNonFollowerWithUqfRule, + EngagementSpammerNonFollowerWithUqfRule, + LowQualityNonFollowerWithUqfRule, + ReadOnlyNonFollowerWithUqfRule, + SpamHighRecallNonFollowerWithUqfRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object TimelineContentControlsPolicy + extends VisibilityPolicy( + tweetRules = TopicsLandingPageTopicRecommendationsPolicy.tweetRules, + userRules = TopicsLandingPageTopicRecommendationsPolicy.userRules + ) + +case object TimelineConversationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusiveNonFollowerTweetLabelRule, + LowQualityTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + BystanderAbusiveNonFollowerTweetLabelRule, + UntrustedUrlAllViewersTweetLabelRule, + DownrankSpamReplyAllViewersTweetLabelRule, + SmyteSpamTweetLabelDropRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + MutedKeywordForTweetRepliesInterstitialRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + 
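+          // Judging by the naming, Interstitial rules leave the reply in the
+          // conversation behind a tap-through warning instead of dropping it.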
SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + AbusiveHighRecallNonFollowerTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + LowQualityHighRecallRule, + CompromisedRule, + SpamHighRecallRule, + AbusiveHighRecallRule, + DownrankSpamReplyAllViewersRule, + ), + policyRuleParams = SensitiveMediaSettingsConversationBaseRules.policyRuleParams + ) + +case object TimelineFollowingActivityPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusiveTweetLabelRule, + BystanderAbusiveTweetLabelRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TimelineInjectionPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SafetyCrisisLevel2DropRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + DoNotAmplifyDropRule, + HighProactiveTosScoreTweetLabelDropRule + ), + userRules = Seq( + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + LikelyIvsLabelNonFollowerDropUserRule, + NsfwTextNonAuthorDropRule + ) + ) + +case object TimelineMentionsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + LowQualityTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + DuplicateMentionUqfTweetLabelRule, + LowQualityMentionTweetLabelRule, + SmyteSpamTweetLabelDropRule, + ToxicityReplyFilterDropNotificationRule, + AbusiveUqfNonFollowerTweetLabelRule, + UntrustedUrlUqfNonFollowerTweetLabelRule, + DownrankSpamReplyUqfNonFollowerTweetLabelRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + SpamHighRecallRule, + DuplicateContentRule, + AbusiveHighRecallRule, + EngagementSpammerNonFollowerWithUqfRule, + EngagementSpammerHighRecallNonFollowerWithUqfRule, + DownrankSpamReplyNonFollowerWithUqfRule + ) + ) + +case object TweetEngagersPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules, + userRules = Seq( + CompromisedNonFollowerWithUqfRule, + EngagementSpammerNonFollowerWithUqfRule, + LowQualityNonFollowerWithUqfRule, + ReadOnlyNonFollowerWithUqfRule, + SpamHighRecallNonFollowerWithUqfRule + ) + ) + +case object TweetWritesApiPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + 
AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object QuotedTweetRulesPolicy + extends VisibilityPolicy( + quotedTweetRules = Seq( + DeactivatedAuthorRule, + ErasedAuthorRule, + OffboardedAuthorRule, + SuspendedAuthorRule, + AuthorBlocksOuterAuthorRule, + ViewerBlocksAuthorRule, + AuthorBlocksViewerDropRule, + ViewerMutesAndDoesNotFollowAuthorRule, + ProtectedQuoteTweetAuthorRule + ) + ) + +case object TweetDetailPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AuthorBlocksViewerDropRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + MutedKeywordForQuotedTweetTweetDetailInterstitialRule, + ) + ++ LimitedEngagementBaseRules.tweetRules, + policyRuleParams = SensitiveMediaSettingsTweetDetailBaseRules.policyRuleParams + ) + +case object BaseTweetDetailPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AuthorBlocksViewerDropRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + 
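+          // Tweet detail uses the AvoidAdPlacement variants below, which appear
+          // to restrict only ad placement around the tweet rather than excluding
+          // it from amplification surfaces in general.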
NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionTweetLabelAvoidRule, + NsfwHighRecallTweetLabelAvoidRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule, + DoNotAmplifyTweetLabelAvoidRule, + NsfaHighPrecisionTweetLabelAvoidRule, + MutedKeywordForQuotedTweetTweetDetailInterstitialRule, + ) + ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TweetDetailWithInjectionsHydrationPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + MutedKeywordForQuotedTweetTweetDetailInterstitialRule, + ReportedTweetInterstitialRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = UserTimelineRules.UserRules + ) + +case object TweetDetailNonTooPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropAllExclusiveTweetsRule, + DropAllTrustedFriendsTweetsRule, + ) ++ BaseTweetDetailPolicy.tweetRules + ) + +case object RecosWritePathPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AbusiveTweetLabelRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + DuplicateContentTweetLabelDropRule, + BystanderAbusiveTweetLabelRule, + SmyteSpamTweetLabelDropRule + ), + userRules = Seq(NsfwTextNonAuthorDropRule) + ) + +case object BrandSafetyPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + NsfaHighRecallTweetLabelInterstitialRule + ), + userRules = Seq(NsfwTextNonAuthorDropRule) + ) + +case object VideoAdsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object AppealsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + NsfwCardImageAllUsersTweetLabelRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + ) + ) + +case object TimelineConversationsDownrankingPolicy + extends VisibilityPolicy( + tweetRules = Seq( + HighToxicityScoreDownrankAbusiveQualitySectionRule, + UntrustedUrlConversationsTweetLabelRule, + DownrankSpamReplyConversationsTweetLabelRule, + DownrankSpamReplyConversationsAuthorLabelRule, + HighProactiveTosScoreTweetLabelDownrankingRule, + SafetyCrisisLevel3SectionRule, + SafetyCrisisLevel4SectionRule, + DoNotAmplifySectionRule, + DoNotAmplifySectionUserRule, + NotGraduatedConversationsAuthorLabelRule, + HighSpammyTweetContentScoreConvoDownrankAbusiveQualityRule, + HighCryptospamScoreConvoDownrankAbusiveQualityRule, + CopypastaSpamAbusiveQualityTweetLabelRule, + HighToxicityScoreDownrankLowQualitySectionRule, + 
HighPSpammyTweetScoreDownrankLowQualitySectionRule, + RitoActionedTweetDownrankLowQualitySectionRule, + HighToxicityScoreDownrankHighQualitySectionRule, + ) + ) + +case object TimelineConversationsDownrankingMinimalPolicy + extends VisibilityPolicy( + tweetRules = Seq( + HighProactiveTosScoreTweetLabelDownrankingRule + ) + ) + +case object TimelineHomeRecommendationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.union( + RecommendationsPolicy.tweetRules.filter( + _ != NsfwHighPrecisionTweetLabelRule + ), + Seq( + SafetyCrisisLevel2DropRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + HighProactiveTosScoreTweetLabelDropRule, + NsfwHighRecallTweetLabelRule, + ), + BaseTimelineHomePolicy.tweetRules, + ), + userRules = VisibilityPolicy.union( + RecommendationsPolicy.userRules, + BaseTimelineHomePolicy.userRules + ) + ) + +case object TimelineHomeTopicFollowRecommendationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.union( + Seq( + SearchBlacklistTweetLabelRule, + GoreAndViolenceTopicHighRecallTweetLabelRule, + NsfwHighRecallTweetLabelRule, + ), + RecommendationsPolicy.tweetRules + .filterNot( + Seq( + NsfwHighPrecisionTweetLabelRule, + ).contains), + BaseTimelineHomePolicy.tweetRules + ), + userRules = VisibilityPolicy.union( + RecommendationsPolicy.userRules, + BaseTimelineHomePolicy.userRules + ) + ) + +case object TimelineScorerPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AllowAllRule + ) + ) + +case object FollowedTopicsTimelinePolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ) + ) + +case object TopicsLandingPageTopicRecommendationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.union( + Seq( + SearchBlacklistTweetLabelRule, + GoreAndViolenceTopicHighRecallTweetLabelRule, + NsfwHighRecallTweetLabelRule + ), + RecommendationsPolicy.tweetRules, + BaseTimelineHomePolicy.tweetRules, + ), + userRules = VisibilityPolicy.union( + RecommendationsPolicy.userRules, + BaseTimelineHomePolicy.userRules + ) ++ Seq( + AuthorBlocksViewerDropRule + ) + ) + +case object ExploreRecommendationsPolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropOuterCommunityTweetsRule, + SearchBlacklistTweetLabelRule, + GoreAndViolenceTopicHighRecallTweetLabelRule, + NsfwHighRecallTweetLabelRule, + DropTweetsWithGeoRestrictedMediaRule, + TweetNsfwUserDropRule, + TweetNsfwAdminDropRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + ViewerHasMatchingMutedKeywordForNotificationsRule, + ) ++ VisibilityPolicy.union( + RecommendationsPolicy.tweetRules + ), + userRules = VisibilityPolicy.union( + RecommendationsPolicy.userRules + ) ++ Seq( + AuthorBlocksViewerDropRule, + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule + ) + ) + +case object TombstoningPolicy + extends VisibilityPolicy( + tweetRules = Seq( + TombstoneIf.ViewerIsBlockedByAuthor, + TombstoneIf.AuthorIsProtected, + TombstoneIf.ReplyIsModeratedByRootAuthor, + TombstoneIf.AuthorIsSuspended, + TombstoneIf.AuthorIsDeactivated, + InterstitialIf.ViewerHardMutedAuthor + ) + ) + +case object TweetReplyNudgePolicy + extends VisibilityPolicy( + tweetRules = Seq( + SpamAllUsersTweetLabelRule, + PdnaAllUsersTweetLabelRule, + BounceAllUsersTweetLabelRule, + TweetNsfwAdminDropRule, + TweetNsfwUserDropRule, + NsfwHighRecallAllUsersTweetLabelDropRule, + NsfwHighPrecisionAllUsersTweetLabelDropRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelDropRule, + 
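+          // For the reply-nudge surface every sensitive label seemingly results
+          // in an outright drop; there is no interstitial fallback when deciding
+          // whether a tweet may anchor a nudge.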
NsfwReportedHeuristicsAllUsersTweetLabelDropRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelDropRule, + NsfwCardImageAllUsersTweetLabelDropRule, + NsfwVideoAllUsersTweetLabelDropRule, + NsfwTextAllUsersTweetLabelDropRule, + ), + userRules = Seq( + DropNsfwUserAuthorRule, + DropNsfwAdminAuthorRule, + NsfwTextAllUsersDropRule + ) + ) + +case object HumanizationNudgePolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules + ) + +case object TrendsRepresentativeTweetPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.union( + RecommendationsPolicy.tweetRules, + Seq( + AbusiveHighRecallTweetLabelRule, + BystanderAbusiveTweetLabelRule, + DuplicateContentTweetLabelDropRule, + LowQualityTweetLabelDropRule, + HighProactiveTosScoreTweetLabelDropRule, + NsfaHighRecallTweetLabelRule, + NsfwCardImageAllUsersTweetLabelDropRule, + NsfwHighPrecisionTweetLabelRule, + NsfwHighRecallAllUsersTweetLabelDropRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + PdnaAllUsersTweetLabelRule, + SearchBlacklistTweetLabelRule, + SpamHighRecallTweetLabelDropRule, + UntrustedUrlAllViewersTweetLabelRule, + DownrankSpamReplyAllViewersTweetLabelRule, + HighPSpammyScoreAllViewerDropRule, + DoNotAmplifyAllViewersDropRule, + SmyteSpamTweetLabelDropRule, + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule, + CopypastaSpamAllViewersTweetLabelRule, + ) + ), + userRules = VisibilityPolicy.union( + RecommendationsPolicy.userRules, + Seq( + AbusiveRule, + LowQualityRule, + ReadOnlyRule, + CompromisedRule, + RecommendationsBlacklistRule, + SpamHighRecallRule, + DuplicateContentRule, + NsfwHighPrecisionRule, + NsfwNearPerfectAuthorRule, + NsfwBannerImageRule, + NsfwAvatarImageRule, + EngagementSpammerRule, + EngagementSpammerHighRecallRule, + AbusiveHighRecallRule, + SearchBlacklistRule, + SearchNsfwTextRule, + NsfwHighRecallRule, + TsViolationRule, + DownrankSpamReplyAllViewersRule, + NsfwTextNonAuthorDropRule + ) + ) + ) + +case object AdsCampaignPolicy + extends VisibilityPolicy( + userRules = Seq(SuspendedAuthorRule), + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object AdsManagerPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + AdsManagerDenyListAllUsersTweetLabelRule, + ) + ) + +case object AdsReportingDashboardPolicy + extends VisibilityPolicy( + tweetRules = AdsManagerPolicy.tweetRules, + userRules = AdsCampaignPolicy.userRules + ) + +case object BirdwatchNoteAuthorPolicy + extends VisibilityPolicy( + userRules = Seq( + SuspendedAuthorRule, + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule + ) + ) + +case object BirdwatchNoteTweetsTimelinePolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + MutedRetweetsRule, + AuthorBlocksViewerDropRule, + ViewerMutesAuthorRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object BirdwatchNeedsYourHelpNotificationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, 
+ ViewerHasMatchingMutedKeywordForNotificationsRule, + ) + ) + +case object ForDevelopmentOnlyPolicy + extends VisibilityPolicy( + userRules = Seq.empty, + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object UserProfileHeaderPolicy + extends VisibilityPolicy( + userRules = Seq.empty, + tweetRules = Seq(DropAllRule) + ) + +case object UserScopedTimelinePolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules, + tweetRules = Seq(DropAllRule) + ) + +case object TweetScopedTimelinePolicy + extends VisibilityPolicy( + userRules = UserTimelineRules.UserRules, + tweetRules = Seq.empty + ) + +case object SoftInterventionPivotPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules + ) + +case object CuratedTrendsRepresentativeTweetPolicy + extends VisibilityPolicy( + userRules = Seq( + SuspendedAuthorRule, + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAndDoesNotFollowAuthorRule + ) + ) + +case object CommunitiesPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + RetweetDropRule, + AbusePolicyEpisodicTweetLabelDropRule, + EmergencyDropRule, + SafetyCrisisLevel4DropRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object TimelineHomeCommunitiesPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.union( + Seq( + DropAllAuthorRemovedCommunityTweetsRule, + DropAllHiddenCommunityTweetsRule, + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + ), + VisibilityPolicy.baseQuotedTweetTombstoneRules, + CommunitiesPolicy.tweetRules, + ), + userRules = Seq( + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule, + ) + ) + +case object TimelineHomePromotedHydrationPolicy + extends VisibilityPolicy( + tweetRules = Seq( + ViewerHasMatchingMutedKeywordForHomeTimelinePromotedTweetRule, + ViewerMutesAuthorHomeTimelinePromotedTweetRule, + ViewerBlocksAuthorHomeTimelinePromotedTweetRule + ) ++ TimelineHomeHydrationPolicy.tweetRules, + policyRuleParams = TimelineHomeHydrationPolicy.policyRuleParams + ) + +case object SpacesPolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceDoNotAmplifyAllUsersDropRule, + SpaceNsfwHighPrecisionNonFollowerDropRule), + userRules = Seq( + AuthorBlocksViewerDropRule + ) + ) + +case object SpacesSellerApplicationStatusPolicy + extends VisibilityPolicy( + userRules = Seq( + ViewerIsNotAuthorDropRule + ) + ) + +case object SpacesParticipantsPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropAllRule), + userRules = Seq( + AuthorBlocksViewerDropRule, + SuspendedAuthorRule + ) + ) + +case object SpacesSharingPolicy + extends VisibilityPolicy( + tweetRules = TweetDetailPolicy.tweetRules, + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ), + policyRuleParams = TweetDetailPolicy.policyRuleParams + ) + +case object SpaceFleetlinePolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceDoNotAmplifyNonFollowerDropRule, + SpaceCoordHarmfulActivityHighRecallNonFollowerDropRule, + SpaceUntrustedUrlNonFollowerDropRule, + SpaceMisleadingHighRecallNonFollowerDropRule, + SpaceNsfwHighPrecisionAllUsersInterstitialRule + ), + userRules = Seq( + TsViolationRule, + DoNotAmplifyNonFollowerRule, + NotGraduatedNonFollowerRule, + 
LikelyIvsLabelNonFollowerDropUserRule, + UserAbusiveNonFollowerDropRule + ) + ) + +case object SpaceNotificationsPolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceHatefulHighRecallAllUsersDropRule, + SpaceViolenceHighRecallAllUsersDropRule, + SpaceDoNotAmplifyAllUsersDropRule, + SpaceCoordHarmfulActivityHighRecallAllUsersDropRule, + SpaceUntrustedUrlNonFollowerDropRule, + SpaceMisleadingHighRecallNonFollowerDropRule, + SpaceNsfwHighPrecisionAllUsersDropRule, + SpaceNsfwHighRecallAllUsersDropRule, + ViewerHasMatchingMutedKeywordInSpaceTitleForNotificationsRule + ), + userRules = Seq( + ViewerMutesAuthorRule, + ViewerBlocksAuthorRule, + AuthorBlocksViewerDropRule, + TsViolationRule, + DoNotAmplifyUserRule, + AbusiveRule, + SearchBlacklistRule, + SearchNsfwTextRule, + RecommendationsBlacklistRule, + NotGraduatedRule, + SpamHighRecallRule, + AbusiveHighRecallRule, + UserBlinkWorstAllUsersDropRule, + UserNsfwNearPerfectNonFollowerDropRule, + SpaceNsfwHighPrecisionNonFollowerDropRule, + UserNsfwAvatarImageNonFollowerDropRule, + UserNsfwBannerImageNonFollowerDropRule + ) + ) + +case object SpaceTweetAvatarHomeTimelinePolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceDoNotAmplifyNonFollowerDropRule, + SpaceCoordHarmfulActivityHighRecallNonFollowerDropRule, + SpaceUntrustedUrlNonFollowerDropRule, + SpaceMisleadingHighRecallNonFollowerDropRule, + SpaceNsfwHighPrecisionAllUsersDropRule, + SpaceNsfwHighPrecisionAllUsersInterstitialRule + ), + userRules = Seq( + TsViolationRule, + DoNotAmplifyUserRule, + NotGraduatedNonFollowerRule, + AbusiveRule, + SearchBlacklistRule, + SearchNsfwTextRule, + RecommendationsBlacklistRule, + SpamHighRecallRule, + AbusiveHighRecallRule, + UserBlinkWorstAllUsersDropRule, + UserNsfwNearPerfectNonFollowerDropRule, + SpaceNsfwHighPrecisionNonFollowerDropRule, + UserNsfwAvatarImageNonFollowerDropRule, + UserNsfwBannerImageNonFollowerDropRule + ) + ) + +case object SpaceHomeTimelineUprankingPolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceDoNotAmplifyNonFollowerDropRule, + SpaceCoordHarmfulActivityHighRecallNonFollowerDropRule, + SpaceUntrustedUrlNonFollowerDropRule, + SpaceMisleadingHighRecallNonFollowerDropRule, + SpaceNsfwHighPrecisionNonFollowerDropRule, + SpaceNsfwHighPrecisionSafeSearchNonFollowerDropRule, + SpaceNsfwHighRecallSafeSearchNonFollowerDropRule + ), + userRules = Seq( + TsViolationRule, + DoNotAmplifyUserRule, + NotGraduatedRule, + AbusiveRule, + SearchBlacklistRule, + SearchNsfwTextRule, + RecommendationsBlacklistRule, + SpamHighRecallRule, + AbusiveHighRecallRule, + UserBlinkWorstAllUsersDropRule, + UserNsfwNearPerfectNonFollowerDropRule, + UserNsfwAvatarImageNonFollowerDropRule, + UserNsfwBannerImageNonFollowerDropRule + ) + ) + +case object SpaceJoinScreenPolicy + extends VisibilityPolicy( + spaceRules = Seq( + SpaceNsfwHighPrecisionAllUsersInterstitialRule + ) + ) + +case object KitchenSinkDevelopmentPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules.diff( + Seq( + BounceTweetLabelRule, + DropExclusiveTweetContentRule, + DropTrustedFriendsTweetContentRule + ) + ) ++ Seq( + BounceTweetLabelTombstoneRule, + TombstoneExclusiveTweetContentRule, + TombstoneTrustedFriendsTweetContentRule) + ++ Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ViewerReportsAuthorInterstitialRule, + ViewerMutesAuthorInterstitialRule, + ViewerBlocksAuthorInterstitialRule, + MutedKeywordForTweetRepliesInterstitialRule, + ReportedTweetInterstitialRule, + 
NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ExperimentalNudgeLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorTombstoneRule, + SuspendedAuthorRule + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableRetweetTombstoneRule, + DeactivatedUserUnavailableRetweetTombstoneRule, + OffBoardedUserUnavailableRetweetTombstoneRule, + ErasedUserUnavailableRetweetTombstoneRule, + ProtectedUserUnavailableRetweetTombstoneRule, + AuthorBlocksViewerUserUnavailableRetweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableRetweetTombstoneRule, + ViewerMutesAuthorUserUnavailableRetweetTombstoneRule, + SuspendedUserUnavailableInnerQuotedTweetTombstoneRule, + DeactivatedUserUnavailableInnerQuotedTweetTombstoneRule, + OffBoardedUserUnavailableInnerQuotedTweetTombstoneRule, + ErasedUserUnavailableInnerQuotedTweetTombstoneRule, + ProtectedUserUnavailableInnerQuotedTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + deletedTweetRules = Seq( + TombstoneDeletedOuterTweetRule, + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule + ), + mediaRules = VisibilityPolicy.baseMediaRules + ) + +case object CurationPolicyViolationsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ Seq( + DoNotAmplifyAllViewersDropRule, + ), + userRules = Seq( + DoNotAmplifyUserRule, + TsViolationRule + ) + ) + +case object GraphqlDefaultPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules + ) + +case object GryphonDecksAndColumnsSharingPolicy + extends VisibilityPolicy( + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ), + tweetRules = Seq(DropAllRule) + ) + +case object UserSettingsPolicy + extends VisibilityPolicy( + userRules = Seq(ViewerIsNotAuthorDropRule), + tweetRules = Seq(DropAllRule) + ) + +case object BlockMuteUsersTimelinePolicy + extends VisibilityPolicy( + userRules = Seq(SuspendedAuthorRule), + tweetRules = Seq(DropAllRule) + ) + +case object TopicRecommendationsPolicy + extends VisibilityPolicy( + tweetRules = + Seq( + NsfwHighRecallTweetLabelRule, + NsfwTextTweetLabelTopicsDropRule + ) + ++ RecommendationsPolicy.tweetRules, + userRules = RecommendationsPolicy.userRules + ) + +case object RitoActionedTweetTimelinePolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseTweetTombstoneRules + ++ Seq( + 
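+          // Tombstone rules replace the tweet with a placeholder rather than
+          // removing it, presumably so the actioned-tweet timeline keeps a
+          // visible slot for each entry.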
AuthorBlocksViewerTombstoneRule, + ProtectedAuthorTombstoneRule + ), + deletedTweetRules = Seq( + TombstoneDeletedOuterTweetRule, + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + ) + +case object EmbeddedTweetsPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetTombstoneRules + ++ Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) + ++ LimitedEngagementBaseRules.tweetRules, + deletedTweetRules = Seq( + TombstoneDeletedOuterTweetRule, + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + ) + ) + +case object EmbedTweetMarkupPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropStaleTweetsRule) ++ + VisibilityPolicy.baseTweetTombstoneRules + ++ Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) + ++ LimitedEngagementBaseRules.tweetRules, + deletedTweetRules = Seq( + TombstoneDeletedOuterTweetRule, + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + ) + +case object ArticleTweetTimelinePolicy + extends VisibilityPolicy( + tweetRules = + VisibilityPolicy.baseTweetRules ++ + Seq( + ViewerHasMatchingMutedKeywordForHomeTimelineRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ LimitedEngagementBaseRules.tweetRules, + userRules = Seq( + AuthorBlocksViewerDropRule, + ViewerBlocksAuthorRule, + ViewerMutesAuthorRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule + ) + ) + +case object ConversationFocalPrehydrationPolicy + extends VisibilityPolicy( + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ) + ) + +case object ConversationFocalTweetPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetTombstoneRules + ++ Seq( + DynamicProductAdDropTweetLabelRule, + AuthorBlocksViewerTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + 
SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + ReportedTweetInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule, + MutedKeywordForQuotedTweetTweetDetailInterstitialRule, + ViewerMutesAuthorInnerQuotedTweetInterstitialRule, + ViewerBlocksAuthorInnerQuotedTweetInterstitialRule, + ) + ++ LimitedEngagementBaseRules.tweetRules + ++ ConversationsAdAvoidanceRules.tweetRules, + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + UserUnavailableTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = ConversationsAdAvoidanceRules.policyRuleParams + ++ SensitiveMediaSettingsConversationBaseRules.policyRuleParams + ) + +case object ConversationReplyPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetTombstoneRules + ++ Seq( + LowQualityTweetLabelTombstoneRule, + SpamHighRecallTweetLabelTombstoneRule, + DuplicateContentTweetLabelTombstoneRule, + DeciderableSpamHighRecallAuthorLabelTombstoneRule, + SmyteSpamTweetLabelTombstoneRule, + AuthorBlocksViewerTombstoneRule, + ToxicityReplyFilterRule, + DynamicProductAdDropTweetLabelRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwHighPrecisionTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceHighPrecisionDropSettingLeveTombstoneRule, + 
SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwReportedHeuristicsTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.ViolentMediaGoreAndViolenceReportedHeuristicsDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.AdultMediaNsfwCardImageTweetLabelDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwUserTweetFlagDropSettingLevelTombstoneRule, + SensitiveMediaTweetDropSettingLevelTombstoneRules.OtherSensitiveMediaNsfwAdminTweetFlagDropSettingLevelTombstoneRule, + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + MutedKeywordForTweetRepliesInterstitialRule, + ReportedTweetInterstitialRule, + ViewerBlocksAuthorInterstitialRule, + ViewerMutesAuthorInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwHighPrecisionTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceHighPrecisionInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwReportedHeuristicsTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.ViolentMediaGoreAndViolenceReportedHeuristicsInterstitialRule, + SensitiveMediaTweetInterstitialRules.AdultMediaNsfwCardImageTweetLabelInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwUserTweetFlagInterstitialRule, + SensitiveMediaTweetInterstitialRules.OtherSensitiveMediaNsfwAdminTweetFlagInterstitialRule, + GoreAndViolenceHighPrecisionAvoidAllUsersTweetLabelRule, + NsfwReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAvoidAdPlacementAllUsersTweetLabelRule, + NsfwCardImageAvoidAdPlacementAllUsersTweetLabelRule, + ) + ++ LimitedEngagementBaseRules.tweetRules + ++ ConversationsAdAvoidanceRules.tweetRules, + userRules = Seq( + LowQualityRule, + ReadOnlyRule, + LowQualityHighRecallRule, + CompromisedRule, + DeciderableSpamHighRecallRule + ), + deletedTweetRules = Seq( + TombstoneDeletedOuterTweetRule, + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + UserUnavailableTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = ConversationsAdAvoidanceRules.policyRuleParams + ++ SensitiveMediaSettingsConversationBaseRules.policyRuleParams + ) + +case object AdsBusinessSettingsPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropAllRule) + ) + +case object UserMilestoneRecommendationPolicy + extends VisibilityPolicy( + userRules = RecommendationsPolicy.userRules ++ Seq( + ) + ) + +case object TrustedFriendsUserListPolicy + extends VisibilityPolicy( + tweetRules = Seq(DropAllRule), + userRules = Seq( + ViewerBlocksAuthorRule + ) + ) + +case object QuickPromoteTweetEligibilityPolicy + extends 
VisibilityPolicy( + tweetRules = TweetDetailPolicy.tweetRules, + userRules = UserTimelineRules.UserRules, + policyRuleParams = TweetDetailPolicy.policyRuleParams + ) + +case object ReportCenterPolicy + extends VisibilityPolicy( + tweetRules = ConversationFocalTweetPolicy.tweetRules.diff( + ConversationsAdAvoidanceRules.tweetRules + ), + deletedTweetRules = Seq( + TombstoneBounceDeletedOuterTweetRule, + TombstoneDeletedQuotedTweetRule, + TombstoneBounceDeletedQuotedTweetRule, + TombstoneDeletedOuterTweetRule, + ), + userUnavailableStateRules = Seq( + SuspendedUserUnavailableTweetTombstoneRule, + DeactivatedUserUnavailableTweetTombstoneRule, + OffBoardedUserUnavailableTweetTombstoneRule, + ErasedUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + AuthorBlocksViewerUserUnavailableInnerQuotedTweetTombstoneRule, + UserUnavailableTweetTombstoneRule, + ViewerBlocksAuthorUserUnavailableInnerQuotedTweetInterstitialRule, + ViewerMutesAuthorUserUnavailableInnerQuotedTweetInterstitialRule + ), + policyRuleParams = ConversationFocalTweetPolicy.policyRuleParams + ) + +case object ConversationInjectedTweetPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetRules ++ + Seq( + AbusePolicyEpisodicTweetLabelInterstitialRule, + EmergencyDynamicInterstitialRule, + NsfwHighPrecisionInterstitialAllUsersTweetLabelRule, + GoreAndViolenceHighPrecisionAllUsersTweetLabelRule, + NsfwReportedHeuristicsAllUsersTweetLabelRule, + GoreAndViolenceReportedHeuristicsAllUsersTweetLabelRule, + NsfwCardImageAllUsersTweetLabelRule, + ) ++ + LimitedEngagementBaseRules.tweetRules ++ Seq( + SkipTweetDetailLimitedEngagementTweetLabelRule + ) + ) + +case object EditHistoryTimelinePolicy + extends VisibilityPolicy( + tweetRules = ConversationReplyPolicy.tweetRules, + policyRuleParams = ConversationReplyPolicy.policyRuleParams, + deletedTweetRules = ConversationReplyPolicy.deletedTweetRules, + userUnavailableStateRules = ConversationReplyPolicy.userUnavailableStateRules) + +case object UserSelfViewOnlyPolicy + extends VisibilityPolicy( + userRules = Seq(ViewerIsNotAuthorDropRule), + tweetRules = Seq(DropAllRule) + ) + +case object TwitterArticleComposePolicy + extends VisibilityPolicy( + twitterArticleRules = Seq( + ViewerIsNotAuthorDropRule + ) + ) + +case object TwitterArticleProfileTabPolicy + extends VisibilityPolicy( + twitterArticleRules = Seq( + AuthorBlocksViewerDropRule + ) + ) + +case object TwitterArticleReadPolicy + extends VisibilityPolicy( + twitterArticleRules = Seq( + AuthorBlocksViewerDropRule, + ) + ) + +case object ContentControlToolInstallPolicy + extends VisibilityPolicy( + userRules = UserProfileHeaderPolicy.userRules, + tweetRules = UserProfileHeaderPolicy.tweetRules + ) + +case object TimelineProfileSpacesPolicy + extends VisibilityPolicy( + userRules = UserProfileHeaderPolicy.userRules, + tweetRules = UserProfileHeaderPolicy.tweetRules + ) + +case object TimelineFavoritesSelfViewPolicy + extends VisibilityPolicy( + tweetRules = TimelineFavoritesPolicy.tweetRules.diff(Seq(DropStaleTweetsRule)), + policyRuleParams = TimelineFavoritesPolicy.policyRuleParams, + deletedTweetRules = TimelineFavoritesPolicy.deletedTweetRules, + userUnavailableStateRules = TimelineFavoritesPolicy.userUnavailableStateRules + ) + +case object BaseQigPolicy + extends VisibilityPolicy( + tweetRules = Seq( + AbusePolicyEpisodicTweetLabelDropRule, + AutomationTweetLabelRule, + DoNotAmplifyDropRule, + DownrankSpamReplyTweetLabelRule, + DuplicateContentTweetLabelDropRule, + 
DuplicateMentionTweetLabelRule, + NsfwHighPrecisionTweetLabelRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + LikelyIvsLabelNonFollowerDropUserRule, + NsfwCardImageTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + NsfwTextTweetLabelDropRule, + NsfwVideoTweetLabelDropRule, + PdnaTweetLabelRule, + SafetyCrisisLevel3DropRule, + SafetyCrisisLevel4DropRule, + SearchBlacklistHighRecallTweetLabelDropRule, + SearchBlacklistTweetLabelRule, + SmyteSpamTweetLabelDropRule, + SpamHighRecallTweetLabelDropRule, + ), + userRules = Seq( + DuplicateContentRule, + EngagementSpammerHighRecallRule, + EngagementSpammerRule, + NsfwAvatarImageRule, + NsfwBannerImageRule, + NsfwHighPrecisionRule, + NsfwHighRecallRule, + NsfwSensitiveRule, + ReadOnlyRule, + RecommendationsBlacklistRule, + SearchBlacklistRule, + SpamHighRecallRule + )) + +case object NotificationsQigPolicy + extends VisibilityPolicy( + tweetRules = BaseQigPolicy.tweetRules ++ Seq( + DropAllCommunityTweetsRule, + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + HighProactiveTosScoreTweetLabelDropSearchRule, + LowQualityTweetLabelDropRule, + NsfwHighPrecisionRule, + NsfwHighRecallRule, + NsfwNearPerfectAuthorRule, + NsfwSensitiveRule, + ), + userRules = BaseQigPolicy.userRules ++ Seq( + AbusiveRule, + LowQualityRule, + CompromisedRule, + ViewerBlocksAuthorViewerOptInBlockingOnSearchRule, + ViewerMutesAuthorViewerOptInBlockingOnSearchRule, + DropNsfwAdminAuthorViewerOptInFilteringOnSearchRule, + NsfwNearPerfectAuthorRule + ) + ) + +case object ShoppingManagerSpyModePolicy + extends VisibilityPolicy( + tweetRules = Seq( + DropAllRule + ), + userRules = Seq( + SuspendedAuthorRule, + DeactivatedAuthorRule, + ErasedAuthorRule, + OffboardedAuthorRule + ) + ) + +case object ZipbirdConsumerArchivesPolicy + extends VisibilityPolicy( + tweetRules = VisibilityPolicy.baseTweetTombstoneRules, + userRules = Seq( + AuthorBlocksViewerDropRule, + ProtectedAuthorDropRule, + SuspendedAuthorRule, + ), + userUnavailableStateRules = Seq( + AuthorBlocksViewerUserUnavailableTweetTombstoneRule, + ProtectedUserUnavailableTweetTombstoneRule, + SuspendedUserUnavailableTweetTombstoneRule, + ), + deletedTweetRules = Seq( + TombstoneDeletedTweetRule, + TombstoneBounceDeletedTweetRule, + ) + ) + +case class MixedVisibilityPolicy( + originalPolicy: VisibilityPolicy, + additionalTweetRules: Seq[Rule]) + extends VisibilityPolicy( + tweetRules = (additionalTweetRules ++ originalPolicy.tweetRules) + .sortWith(_.actionBuilder.actionSeverity > _.actionBuilder.actionSeverity), + userRules = originalPolicy.userRules, + cardRules = originalPolicy.cardRules, + quotedTweetRules = originalPolicy.quotedTweetRules, + dmRules = originalPolicy.dmRules, + dmConversationRules = originalPolicy.dmConversationRules, + dmEventRules = originalPolicy.dmEventRules, + spaceRules = originalPolicy.spaceRules, + userUnavailableStateRules = originalPolicy.userUnavailableStateRules, + twitterArticleRules = originalPolicy.twitterArticleRules, + deletedTweetRules = originalPolicy.deletedTweetRules, + mediaRules = originalPolicy.mediaRules, + communityRules = originalPolicy.communityRules, + policyRuleParams = originalPolicy.policyRuleParams + ) + +case object TweetAwardPolicy + extends VisibilityPolicy( + userRules = Seq.empty, + tweetRules = + VisibilityPolicy.baseTweetRules ++ Seq( + EmergencyDropRule, + NsfwHighPrecisionTweetLabelRule, + NsfwHighRecallTweetLabelRule, + NsfwReportedHeuristicsTweetLabelRule, + 
NsfwCardImageTweetLabelRule, + NsfwVideoTweetLabelDropRule, + NsfwTextTweetLabelDropRule, + GoreAndViolenceHighPrecisionTweetLabelRule, + GoreAndViolenceReportedHeuristicsTweetLabelRule, + GoreAndViolenceTweetLabelRule, + AbusePolicyEpisodicTweetLabelDropRule, + AbusiveTweetLabelRule, + BystanderAbusiveTweetLabelRule + ) + ) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/BUILD new file mode 100644 index 000000000..424c99ac4 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/BUILD @@ -0,0 +1,37 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "3rdparty/jvm/com/squareup/okhttp:okhttp3", + "abdecider/src/main/scala", + "configapi/configapi-core", + "decider/src/main/scala", + "scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers", + "servo/decider/src/main/scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/scala/com/twitter/takedown/util", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core/src/main/scala/com/twitter/stitch", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/RuleGenerator.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/RuleGenerator.scala new file mode 100644 index 000000000..262134636 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/RuleGenerator.scala @@ -0,0 +1,8 @@ +package com.twitter.visibility.rules.generators + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.Rule + +trait RuleGenerator { + def rulesForSurface(safetyLevel: SafetyLevel): Seq[Rule] +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetRuleGenerator.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetRuleGenerator.scala new file mode 100644 index 000000000..6bdb965a1 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetRuleGenerator.scala @@ -0,0 +1,321 @@ +package 
com.twitter.visibility.rules.generators + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevelGroup +import com.twitter.visibility.models.ViolationLevel +import com.twitter.visibility.rules.FreedomOfSpeechNotReachActions +import com.twitter.visibility.rules.FreedomOfSpeechNotReachRules +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.generators.TweetRuleGenerator.violationLevelPolicies + +object TweetRuleGenerator { + private val level3LimitedActions: Seq[String] = Seq( + "like", + "reply", + "retweet", + "quote_tweet", + "share_tweet_via", + "add_to_bookmarks", + "pin_to_profile", + "copy_link", + "send_via_dm") + private val violationLevelPolicies: Map[ + ViolationLevel, + Map[UserType, TweetVisibilityPolicy] + ] = Map( + ViolationLevel.Level1 -> Map( + UserType.Follower -> TweetVisibilityPolicy + .builder() + .addGlobalRule(FreedomOfSpeechNotReachActions.SoftInterventionAvoidAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Notifications, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Recommendations, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Search, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TopicRecommendations, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelRule( + SafetyLevel.TimelineHomeRecommendations, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelRule( + SafetyLevel.TrendsRepresentativeTweet, + FreedomOfSpeechNotReachActions.DropAction()) + .build, + UserType.Author -> TweetVisibilityPolicy + .builder() + .addGlobalRule(FreedomOfSpeechNotReachActions.AppealableAction()) + .build, + UserType.Other -> TweetVisibilityPolicy + .builder() + .addGlobalRule(FreedomOfSpeechNotReachActions.SoftInterventionAvoidAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Notifications, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Recommendations, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TimelineHome, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.Search, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TopicRecommendations, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelRule( + SafetyLevel.TrendsRepresentativeTweet, + FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelRule( + SafetyLevel.ConversationReply, + FreedomOfSpeechNotReachActions.SoftInterventionAvoidAbusiveQualityReplyAction()) + .build, + ), + ViolationLevel.Level3 -> Map( + UserType.Follower -> TweetVisibilityPolicy + .builder() + .addGlobalRule(FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TimelineProfile, + FreedomOfSpeechNotReachActions.SoftInterventionAvoidLimitedEngagementsAction( + limitedActionStrings = Some(level3LimitedActions)) + ) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TweetDetails, + FreedomOfSpeechNotReachActions.SoftInterventionAvoidLimitedEngagementsAction( + limitedActionStrings = Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ConversationReply, + FreedomOfSpeechNotReachActions.SoftInterventionAvoidLimitedEngagementsAction( + limitedActionStrings = Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ConversationFocalTweet, + 
FreedomOfSpeechNotReachActions.SoftInterventionAvoidLimitedEngagementsAction( + limitedActionStrings = Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.TimelineMedia, + FreedomOfSpeechNotReachActions + .SoftInterventionAvoidLimitedEngagementsAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ProfileMixerMedia, + FreedomOfSpeechNotReachActions + .SoftInterventionAvoidLimitedEngagementsAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.TimelineFavorites, + FreedomOfSpeechNotReachActions + .SoftInterventionAvoidLimitedEngagementsAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ProfileMixerFavorites, + FreedomOfSpeechNotReachActions + .SoftInterventionAvoidLimitedEngagementsAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .build, + UserType.Author -> TweetVisibilityPolicy + .builder() + .addGlobalRule( + FreedomOfSpeechNotReachActions.AppealableAvoidLimitedEngagementsAction( + limitedActionStrings = Some(level3LimitedActions)) + ) + .build, + UserType.Other -> TweetVisibilityPolicy + .builder() + .addGlobalRule(FreedomOfSpeechNotReachActions.DropAction()) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TimelineProfile, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelGroupRule( + SafetyLevelGroup.TweetDetails, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ConversationReply, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ConversationFocalTweet, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.TimelineMedia, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ProfileMixerMedia, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.TimelineFavorites, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .addSafetyLevelRule( + SafetyLevel.ProfileMixerFavorites, + FreedomOfSpeechNotReachActions + .InterstitialLimitedEngagementsAvoidAction(limitedActionStrings = + Some(level3LimitedActions)) + ) + .build, + ), + ) +} +sealed trait UserType +object UserType { + case object Author extends UserType + + case object Follower extends UserType + + case object Other extends UserType +} +class TweetRuleGenerator extends RuleGenerator { + + private[rules] val tweetRulesForSurface: Map[SafetyLevel, Seq[Rule]] = generateTweetPolicies() + + private[rules] def getViolationLevelPolicies = violationLevelPolicies + + override def rulesForSurface(safetyLevel: SafetyLevel): Seq[Rule] = + tweetRulesForSurface.getOrElse(safetyLevel, Seq()) + + private def generateRulesForPolicy( + violationLevel: ViolationLevel, + userType: UserType, + tweetVisibilityPolicy: TweetVisibilityPolicy + ): Seq[(SafetyLevel, Rule)] = { + tweetVisibilityPolicy + .getRules() + 
.map { + case (safetyLevel, actionBuilder) => + safetyLevel -> (userType match { + case UserType.Author => + FreedomOfSpeechNotReachRules.ViewerIsAuthorAndTweetHasViolationOfLevel( + violationLevel = violationLevel, + actionBuilder = actionBuilder.withViolationLevel(violationLevel = violationLevel)) + case UserType.Follower => + FreedomOfSpeechNotReachRules.ViewerIsFollowerAndTweetHasViolationOfLevel( + violationLevel = violationLevel, + actionBuilder = actionBuilder.withViolationLevel(violationLevel = violationLevel)) + case UserType.Other => + FreedomOfSpeechNotReachRules.ViewerIsNonFollowerNonAuthorAndTweetHasViolationOfLevel( + violationLevel = violationLevel, + actionBuilder = actionBuilder.withViolationLevel(violationLevel = violationLevel)) + }) + }.toSeq + } + + private def generatePoliciesForViolationLevel( + violationLevel: ViolationLevel + ): Seq[(SafetyLevel, Rule)] = { + getViolationLevelPolicies + .get(violationLevel).map { policiesPerUserType => + Seq(UserType.Author, UserType.Follower, UserType.Other).foldLeft( + List.empty[(UserType, SafetyLevel, Rule)]) { + case (rulesForAllUserTypes, userType) => + rulesForAllUserTypes ++ generateRulesForPolicy( + violationLevel = violationLevel, + userType = userType, + tweetVisibilityPolicy = policiesPerUserType(userType)).map { + case (safetyLevel, rule) => (userType, safetyLevel, rule) + } + } + } + .map(policy => optimizePolicy(policy = policy, violationLevel = violationLevel)) + .getOrElse(List()) + } + + private def injectFallbackRule(rules: Seq[Rule]): Seq[Rule] = { + rules :+ FreedomOfSpeechNotReachRules.TweetHasViolationOfAnyLevelFallbackDropRule + } + + private def optimizePolicy( + policy: Seq[(UserType, SafetyLevel, Rule)], + violationLevel: ViolationLevel + ): Seq[(SafetyLevel, Rule)] = { + val policiesByUserType = policy.groupBy { case (userType, _, _) => userType }.map { + case (userType, aggregated) => + (userType, aggregated.map { case (_, safetyLevel, rules) => (safetyLevel, rules) }) + } + val followerPolicies = aggregateRulesBySafetyLevel( + policiesByUserType.getOrElse(UserType.Follower, Seq())) + val otherPolicies = aggregateRulesBySafetyLevel( + policiesByUserType.getOrElse(UserType.Other, Seq())) + policiesByUserType(UserType.Author) ++ + followerPolicies.collect { + case (safetyLevel, rule) if !otherPolicies.contains(safetyLevel) => + (safetyLevel, rule) + } ++ + otherPolicies.collect { + case (safetyLevel, rule) if !followerPolicies.contains(safetyLevel) => + (safetyLevel, rule) + } ++ + followerPolicies.keySet + .intersect(otherPolicies.keySet).foldLeft(List.empty[(SafetyLevel, Rule)]) { + case (aggr, safetyLevel) + if followerPolicies(safetyLevel).actionBuilder == otherPolicies( + safetyLevel).actionBuilder => + ( + safetyLevel, + FreedomOfSpeechNotReachRules.ViewerIsNonAuthorAndTweetHasViolationOfLevel( + violationLevel = violationLevel, + actionBuilder = followerPolicies(safetyLevel).actionBuilder + )) :: aggr + case (aggr, safetyLevel) => + (safetyLevel, followerPolicies(safetyLevel)) :: + (safetyLevel, otherPolicies(safetyLevel)) :: aggr + } + } + + private def aggregateRulesBySafetyLevel( + policy: Seq[(SafetyLevel, Rule)] + ): Map[SafetyLevel, Rule] = { + policy + .groupBy { + case (safetyLevel, _) => safetyLevel + }.map { + case (safetyLevel, Seq((_, rule))) => + (safetyLevel, rule) + case _ => throw new Exception("Policy optimization failure") + } + } + + private def generateTweetPolicies(): Map[SafetyLevel, Seq[Rule]] = { + Seq(ViolationLevel.Level4, ViolationLevel.Level3, ViolationLevel.Level2, 
ViolationLevel.Level1) + .foldLeft(List.empty[(SafetyLevel, Rule)]) { + case (rulesForAllViolationLevels, violationLevel) => + rulesForAllViolationLevels ++ + generatePoliciesForViolationLevel(violationLevel) + } + .groupBy { case (safetyLevel, _) => safetyLevel } + .map { + case (safetyLevel, list) => + (safetyLevel, injectFallbackRule(list.map { case (_, rule) => rule })) + } + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetVisibilityPolicy.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetVisibilityPolicy.scala new file mode 100644 index 000000000..1b9de7a1c --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/generators/TweetVisibilityPolicy.scala @@ -0,0 +1,74 @@ +package com.twitter.visibility.rules.generators + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.models.SafetyLevelGroup +import com.twitter.visibility.rules.Action +import com.twitter.visibility.rules.FreedomOfSpeechNotReachActions.FreedomOfSpeechNotReachActionBuilder + +class TweetVisibilityPolicy( + rules: Map[SafetyLevel, FreedomOfSpeechNotReachActionBuilder[_ <: Action]] = Map()) { + def getRules(): Map[SafetyLevel, FreedomOfSpeechNotReachActionBuilder[_ <: Action]] = rules +} + +object TweetVisibilityPolicy { + private[generators] val allApplicableSurfaces = + SafetyLevel.List.toSet -- + SafetyLevelGroup.Special.levels -- + Set( + SafetyLevel.SearchPeopleTypeahead, + SafetyLevel.UserProfileHeader, + SafetyLevel.UserScopedTimeline, + SafetyLevel.SpacesParticipants, + SafetyLevel.GryphonDecksAndColumns, + SafetyLevel.UserSettings, + SafetyLevel.BlockMuteUsersTimeline, + SafetyLevel.AdsBusinessSettings, + SafetyLevel.TrustedFriendsUserList, + SafetyLevel.UserSelfViewOnly, + SafetyLevel.ShoppingManagerSpyMode, + ) + + def builder(): TweetVisibilityPolicyBuilder = TweetVisibilityPolicyBuilder() +} + +case class TweetVisibilityPolicyBuilder( + rules: Map[SafetyLevel, FreedomOfSpeechNotReachActionBuilder[_ <: Action]] = Map()) { + + def addGlobalRule[T <: Action]( + actionBuilder: FreedomOfSpeechNotReachActionBuilder[T] + ): TweetVisibilityPolicyBuilder = + copy(rules = + rules ++ TweetVisibilityPolicy.allApplicableSurfaces.map(_ -> actionBuilder)) + + def addSafetyLevelRule[T <: Action]( + safetyLevel: SafetyLevel, + actionBuilder: FreedomOfSpeechNotReachActionBuilder[T] + ): TweetVisibilityPolicyBuilder = { + if (TweetVisibilityPolicy.allApplicableSurfaces.contains(safetyLevel)) { + copy(rules = rules ++ Map(safetyLevel -> actionBuilder)) + } else { + this + } + } + + def addSafetyLevelGroupRule[T <: Action]( + group: SafetyLevelGroup, + actionBuilder: FreedomOfSpeechNotReachActionBuilder[T] + ): TweetVisibilityPolicyBuilder = + copy(rules = + rules ++ group.levels.collect { + case safetyLevel if TweetVisibilityPolicy.allApplicableSurfaces.contains(safetyLevel) => + safetyLevel -> actionBuilder + }) + + def addRuleForAllRemainingSafetyLevels[T <: Action]( + actionBuilder: FreedomOfSpeechNotReachActionBuilder[T] + ): TweetVisibilityPolicyBuilder = + copy(rules = + rules ++ (TweetVisibilityPolicy.allApplicableSurfaces -- rules.keySet) + .map(_ -> actionBuilder).toMap) + + def build: TweetVisibilityPolicy = { + new TweetVisibilityPolicy(rules) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/package.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/package.scala new file mode 100644 index 000000000..2b4019f46 --- /dev/null +++ 
b/visibilitylib/src/main/scala/com/twitter/visibility/rules/package.scala @@ -0,0 +1,5 @@ +package com.twitter.visibility + +package object rules { + type LabelTypeId = Short +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/BUILD new file mode 100644 index 000000000..a65d83f10 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/BUILD @@ -0,0 +1,38 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "3rdparty/jvm/com/squareup/okhttp:okhttp3", + "abdecider/src/main/scala", + "configapi/configapi-core", + "decider/src/main/scala", + "scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers", + "servo/decider/src/main/scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/scala/com/twitter/takedown/util", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core/src/main/scala/com/twitter/stitch", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/generators", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/InjectedPolicyProvider.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/InjectedPolicyProvider.scala new file mode 100644 index 000000000..b9eafbbd6 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/InjectedPolicyProvider.scala @@ -0,0 +1,27 @@ +package com.twitter.visibility.rules.providers + +import com.twitter.visibility.configapi.configs.VisibilityDeciderGates +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.MixedVisibilityPolicy +import com.twitter.visibility.rules.RuleBase +import com.twitter.visibility.rules.generators.TweetRuleGenerator + +class InjectedPolicyProvider( + visibilityDeciderGates: VisibilityDeciderGates, + tweetRuleGenerator: TweetRuleGenerator) + extends PolicyProvider { + + private[rules] val policiesForSurface: Map[SafetyLevel, MixedVisibilityPolicy] = + RuleBase.RuleMap.map { + case (safetyLevel, policy) => + ( + safetyLevel, + MixedVisibilityPolicy( + originalPolicy = policy, 
+ additionalTweetRules = tweetRuleGenerator.rulesForSurface(safetyLevel))) + } + + override def policyForSurface(safetyLevel: SafetyLevel): MixedVisibilityPolicy = { + policiesForSurface(safetyLevel) + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/PolicyProvider.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/PolicyProvider.scala new file mode 100644 index 000000000..a39c0a083 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/PolicyProvider.scala @@ -0,0 +1,8 @@ +package com.twitter.visibility.rules.providers + +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.VisibilityPolicy + +trait PolicyProvider { + def policyForSurface(safetyLevel: SafetyLevel): VisibilityPolicy +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/ProvidedEvaluationContext.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/ProvidedEvaluationContext.scala new file mode 100644 index 000000000..76f6899da --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/providers/ProvidedEvaluationContext.scala @@ -0,0 +1,50 @@ +package com.twitter.visibility.rules.providers + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.timelines.configapi.Params +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.EvaluationContext +import com.twitter.visibility.rules.VisibilityPolicy + +sealed abstract class ProvidedEvaluationContext( + visibilityPolicy: VisibilityPolicy, + params: Params, + statsReceiver: StatsReceiver) + extends EvaluationContext( + visibilityPolicy = visibilityPolicy, + params = params, + statsReceiver = statsReceiver) + +object ProvidedEvaluationContext { + + def injectRuntimeRulesIntoEvaluationContext( + evaluationContext: EvaluationContext, + safetyLevel: Option[SafetyLevel] = None, + policyProviderOpt: Option[PolicyProvider] = None + ): ProvidedEvaluationContext = { + (policyProviderOpt, safetyLevel) match { + case (Some(policyProvider), Some(safetyLevel)) => + new InjectedEvaluationContext( + evaluationContext = evaluationContext, + safetyLevel = safetyLevel, + policyProvider = policyProvider) + case (_, _) => new StaticEvaluationContext(evaluationContext) + } + } +} + +private class StaticEvaluationContext( + evaluationContext: EvaluationContext) + extends ProvidedEvaluationContext( + visibilityPolicy = evaluationContext.visibilityPolicy, + params = evaluationContext.params, + statsReceiver = evaluationContext.statsReceiver) + +private class InjectedEvaluationContext( + evaluationContext: EvaluationContext, + safetyLevel: SafetyLevel, + policyProvider: PolicyProvider) + extends ProvidedEvaluationContext( + visibilityPolicy = policyProvider.policyForSurface(safetyLevel), + params = evaluationContext.params, + statsReceiver = evaluationContext.statsReceiver) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/BUILD new file mode 100644 index 000000000..75953e740 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/BUILD @@ -0,0 +1,38 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "3rdparty/jvm/com/google/guava", + "3rdparty/jvm/com/squareup/okhttp:okhttp3", + "abdecider/src/main/scala", 
+ "configapi/configapi-core", + "decider/src/main/scala", + "scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers", + "servo/decider/src/main/scala", + "snowflake/src/main/scala/com/twitter/snowflake/id", + "src/scala/com/twitter/takedown/util", + "src/thrift/com/twitter/content-health/sensitivemediasettings:sensitivemediasettings-scala", + "src/thrift/com/twitter/gizmoduck:user-thrift-scala", + "src/thrift/com/twitter/search/common:constants-scala", + "src/thrift/com/twitter/spam/rtf:safety-level-scala", + "src/thrift/com/twitter/spam/rtf:safety-result-scala", + "src/thrift/com/twitter/tweetypie:tweet-scala", + "stitch/stitch-core/src/main/scala/com/twitter/stitch", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions", + "visibility/common/src/main/scala/com/twitter/visibility/common/actions/converter/scala", + "visibility/common/src/main/thrift/com/twitter/visibility:action-scala", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/configs", + "visibility/lib/src/main/scala/com/twitter/visibility/configapi/params", + "visibility/lib/src/main/scala/com/twitter/visibility/features", + "visibility/lib/src/main/scala/com/twitter/visibility/models", + "visibility/lib/src/main/scala/com/twitter/visibility/rules", + "visibility/lib/src/main/scala/com/twitter/visibility/rules/providers", + "visibility/lib/src/main/scala/com/twitter/visibility/util", + "visibility/lib/src/main/thrift/com/twitter/visibility/logging:vf-logging-scala", + "visibility/lib/src/main/thrift/com/twitter/visibility/strato:vf-strato-scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/ShimUtils.scala b/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/ShimUtils.scala new file mode 100644 index 000000000..7501d7273 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/rules/utils/ShimUtils.scala @@ -0,0 +1,60 @@ +package com.twitter.visibility.rules.utils + +import com.twitter.visibility.features.Feature +import com.twitter.visibility.features.FeatureMap +import com.twitter.visibility.models.ContentId +import com.twitter.visibility.models.SafetyLevel +import com.twitter.visibility.rules.Filtered +import com.twitter.visibility.rules.Rule +import com.twitter.visibility.rules.RuleBase +import com.twitter.visibility.rules.RuleBase.RuleMap +import com.twitter.visibility.rules.providers.ProvidedEvaluationContext +import com.twitter.visibility.rules.providers.PolicyProvider + +object ShimUtils { + + def preFilterFeatureMap( + featureMap: FeatureMap, + safetyLevel: SafetyLevel, + contentId: ContentId, + evaluationContext: ProvidedEvaluationContext, + policyProviderOpt: Option[PolicyProvider] = None, + ): FeatureMap = { + val safetyLevelRules: Seq[Rule] = policyProviderOpt match { + case Some(policyProvider) => + policyProvider + .policyForSurface(safetyLevel) + .forContentId(contentId) + case _ => RuleMap(safetyLevel).forContentId(contentId) + } + + val afterDisabledRules = + safetyLevelRules.filter(evaluationContext.ruleEnabledInContext) + + val afterMissingFeatureRules = + afterDisabledRules.filter(rule => { + val missingFeatures: Set[Feature[_]] = rule.featureDependencies.collect { + case feature: Feature[_] if !featureMap.contains(feature) => feature + } + if (missingFeatures.isEmpty) { + true + } else { + false + } + }) + + val afterPreFilterRules = afterMissingFeatureRules.filter(rule => { + rule.preFilter(evaluationContext, 
featureMap.constantMap, null) match { + case Filtered => + false + case _ => + true + } + }) + + val filteredFeatureMap = + RuleBase.removeUnusedFeaturesFromFeatureMap(featureMap, afterPreFilterRules) + + filteredFeatureMap + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/util/BUILD b/visibilitylib/src/main/scala/com/twitter/visibility/util/BUILD new file mode 100644 index 000000000..3bfead4c2 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/util/BUILD @@ -0,0 +1,18 @@ +scala_library( + sources = ["*.scala"], + compiler_option_sets = ["fatal_warnings"], + platform = "java8", + strict_deps = True, + tags = ["bazel-compatible"], + dependencies = [ + "abdecider/src/main/scala", + "decider", + "featureswitches/featureswitches-core/src/main/scala", + "featureswitches/featureswitches-core/src/main/scala/com/twitter/featureswitches/v2/builder", + "stitch/stitch-core", + "twitter-config/yaml", + "util-internal/scribe", + "util/util-logging/src/main/scala", + "util/util-stats/src/main/scala", + ], +) diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/util/DeciderUtil.scala b/visibilitylib/src/main/scala/com/twitter/visibility/util/DeciderUtil.scala new file mode 100644 index 000000000..751066c15 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/util/DeciderUtil.scala @@ -0,0 +1,45 @@ +package com.twitter.visibility.util + +import com.twitter.abdecider.ABDeciderFactory +import com.twitter.abdecider.LoggingABDecider +import com.twitter.decider.Decider +import com.twitter.decider.DeciderFactory +import com.twitter.decider.LocalOverrides +import com.twitter.logging._ + +object DeciderUtil { + val DefaultDeciderPath = "/config/com/twitter/visibility/decider.yml" + + private val zone = Option(System.getProperty("dc")).getOrElse("atla") + val DefaultDeciderOverlayPath: Some[String] = Some( + s"/usr/local/config/overlays/visibility-library/visibility-library/prod/$zone/decider_overlay.yml" + ) + + val DefaultABDeciderPath = "/usr/local/config/abdecider/abdecider.yml" + + def mkDecider( + deciderBasePath: String = DefaultDeciderPath, + deciderOverlayPath: Option[String] = DefaultDeciderOverlayPath, + useLocalDeciderOverrides: Boolean = false, + ): Decider = { + val fileBased = new DeciderFactory(Some(deciderBasePath), deciderOverlayPath)() + if (useLocalDeciderOverrides) { + LocalOverrides.decider("visibility-library").orElse(fileBased) + } else { + fileBased + } + } + + def mkLocalDecider: Decider = mkDecider(deciderOverlayPath = None) + + def mkABDecider( + scribeLogger: Option[Logger], + abDeciderPath: String = DefaultABDeciderPath + ): LoggingABDecider = { + ABDeciderFactory( + abDeciderPath, + Some("production"), + scribeLogger = scribeLogger + ).buildWithLogging() + } +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/util/FeatureSwitchUtil.scala b/visibilitylib/src/main/scala/com/twitter/visibility/util/FeatureSwitchUtil.scala new file mode 100644 index 000000000..f6c0d6953 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/util/FeatureSwitchUtil.scala @@ -0,0 +1,22 @@ +package com.twitter.visibility.util + +import com.twitter.abdecider.ABDecider +import com.twitter.featureswitches.v2.FeatureSwitches +import com.twitter.featureswitches.v2.builder.FeatureSwitchesBuilder +import com.twitter.finagle.stats.StatsReceiver + +object FeatureSwitchUtil { + private val LibraryFeaturesConfigPath = "/features/visibility/main" + private val LimitedActionsFeaturesConfigPath = 
"/features/visibility-limited-actions/main" + + def mkVisibilityLibraryFeatureSwitches( + abDecider: ABDecider, + statsReceiver: StatsReceiver + ): FeatureSwitches = + FeatureSwitchesBuilder + .createDefault(LibraryFeaturesConfigPath, abDecider, Some(statsReceiver)).build() + + def mkLimitedActionsFeatureSwitches(statsReceiver: StatsReceiver): FeatureSwitches = + FeatureSwitchesBuilder + .createWithNoExperiments(LimitedActionsFeaturesConfigPath, Some(statsReceiver)).build() +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/util/LoggingUtil.scala b/visibilitylib/src/main/scala/com/twitter/visibility/util/LoggingUtil.scala new file mode 100644 index 000000000..aecd21971 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/util/LoggingUtil.scala @@ -0,0 +1,35 @@ +package com.twitter.visibility.util + +import com.twitter.finagle.stats.StatsReceiver +import com.twitter.logging._ + +object LoggingUtil { + + val ExperimentationLog: String = "vf_abdecider" + + def mkDefaultHandlerFactory(statsReceiver: StatsReceiver): () => Handler = { + QueueingHandler( + maxQueueSize = 10000, + handler = ScribeHandler( + category = "client_event", + formatter = BareFormatter, + statsReceiver = statsReceiver.scope("client_event_scribe"), + level = Some(Level.INFO) + ) + ) + } + + def mkDefaultLoggerFactory(statsReceiver: StatsReceiver): LoggerFactory = { + LoggerFactory( + node = ExperimentationLog, + level = Some(Level.INFO), + useParents = false, + handlers = List(mkDefaultHandlerFactory(statsReceiver)) + ) + } + + def mkDefaultLogger(statsReceiver: StatsReceiver): Logger = { + mkDefaultLoggerFactory(statsReceiver)() + } + +} diff --git a/visibilitylib/src/main/scala/com/twitter/visibility/util/NamingUtils.scala b/visibilitylib/src/main/scala/com/twitter/visibility/util/NamingUtils.scala new file mode 100644 index 000000000..0238b6544 --- /dev/null +++ b/visibilitylib/src/main/scala/com/twitter/visibility/util/NamingUtils.scala @@ -0,0 +1,6 @@ +package com.twitter.visibility.util + +object NamingUtils { + def getFriendlyName(a: Any): String = getFriendlyNameFromClass(a.getClass) + def getFriendlyNameFromClass(a: Class[_]): String = a.getSimpleName.stripSuffix("$") +}