1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
<?php
declare(strict_types=1);
namespace Doctrine\Tests\ORM\Functional;
use Doctrine\ORM\Query\AST\AggregateExpression;
use Doctrine\ORM\Query\AST\Functions\FunctionNode;
use Doctrine\ORM\Query\AST\Node;
use Doctrine\ORM\Query\AST\PathExpression;
use Doctrine\ORM\Query\Parser;
use Doctrine\ORM\Query\SqlWalker;
use Doctrine\ORM\Query\TokenType;
use Doctrine\Tests\Models\CMS\CmsUser;
use Doctrine\Tests\OrmFunctionalTestCase;
require_once __DIR__ . '/../../TestInit.php';
/**
 * Functional tests for user-defined DQL functions registered on the ORM configuration.
 */
class CustomFunctionsTest extends OrmFunctionalTestCase
{
    protected function setUp(): void
    {
        // Select the CMS model set before parent::setUp() runs, which presumably
        // prepares the schema for the selected models.
        $this->useModelSet('cms');
        parent::setUp();
    }

    /**
     * Registers a custom string function (FOO) and a custom numeric function (BAR)
     * through factory callbacks instead of class names, then checks that a query
     * using both still finds the persisted user.
     */
    public function testCustomFunctionDefinedWithCallback(): void
    {
        $user           = new CmsUser();
        $user->name     = 'Bob';
        $user->username = 'Dylan';
        $this->_em->persist($user);
        $this->_em->flush();

        // Instead of defining the function with the class name, we use a callback
        // that receives the function name and returns the AST node instance.
        $this->_em->getConfiguration()->addCustomStringFunction('FOO', static fn ($funcName) => new NoOp($funcName));
        $this->_em->getConfiguration()->addCustomNumericFunction('BAR', static fn ($funcName) => new NoOp($funcName));

        // CmsUser::class resolves to the same FQCN that was previously hard-coded.
        $query = $this->_em->createQuery(
            'SELECT u FROM ' . CmsUser::class . ' u'
            . ' WHERE FOO(u.name) = \'Bob\''
            . ' AND BAR(1) = 1',
        );

        $users = $query->getResult();

        self::assertCount(1, $users);
        // The hydrated entity is expected to be the very instance persisted above.
        self::assertSame($user, $users[0]);
    }

    /**
     * Overrides the built-in COUNT with a custom function node and checks that
     * the aggregate still evaluates correctly.
     */
    public function testCustomFunctionOverride(): void
    {
        $user           = new CmsUser();
        $user->name     = 'Bob';
        $user->username = 'Dylan';
        $this->_em->persist($user);
        $this->_em->flush();

        // ::class keeps the reference refactor-safe; it resolves to
        // 'Doctrine\Tests\ORM\Functional\CustomCount', exactly as before.
        $this->_em->getConfiguration()->addCustomStringFunction('COUNT', CustomCount::class);

        $query      = $this->_em->createQuery('SELECT COUNT(DISTINCT u.id) FROM ' . CmsUser::class . ' u');
        $usersCount = $query->getSingleScalarResult();

        // Loose comparison on purpose: the scalar may presumably come back as an int
        // or a numeric string depending on the database driver.
        self::assertEquals(1, $usersCount);
    }
}
/**
 * Pass-through DQL function: parses NAME(<arithmetic primary>) and emits the SQL
 * of its single argument unchanged.
 */
class NoOp extends FunctionNode
{
    /**
     * The parsed argument. It is used both with a field (FOO(u.name)) and with a
     * literal (BAR(1)), so it is documented as the generic AST node rather than
     * as a PathExpression.
     *
     * @var Node
     */
    private $field;

    public function parse(Parser $parser): void
    {
        // The function name itself, then "(" <argument> ")".
        $parser->match(TokenType::T_IDENTIFIER);
        $parser->match(TokenType::T_OPEN_PARENTHESIS);
        $this->field = $parser->ArithmeticPrimary();
        $parser->match(TokenType::T_CLOSE_PARENTHESIS);
    }

    public function getSql(SqlWalker $sqlWalker): string
    {
        // Emit the argument's own SQL; the function adds nothing around it.
        return $this->field->dispatch($sqlWalker);
    }
}
/**
 * Stand-in for the built-in COUNT: hands both parsing and SQL generation to the
 * parser's standard aggregate-expression handling.
 */
class CustomCount extends FunctionNode
{
    /** Aggregate captured by parse(); stays null until parse() has been called. */
    private AggregateExpression|null $aggregate = null;

    public function parse(Parser $parser): void
    {
        $aggregate       = $parser->AggregateExpression();
        $this->aggregate = $aggregate;
    }

    public function getSql(SqlWalker $sqlWalker): string
    {
        $aggregate = $this->aggregate;

        return $aggregate->dispatch($sqlWalker);
    }
}
|